Val*_*acé 7 multiprocessing pmap tpu jax
我的问题很简单:
我正在一台小型笔记本电脑上进行编码,并且正在使用jax.pmap因为我的代码将在多个 TPU 上运行。我想“假装”拥有多个设备来测试我的代码并尝试不同的事情。
有什么办法可以做到吗?但我怀疑 Jax 内部是否能找到解决方案。谢谢!
您可以通过设置以下环境变量来欺骗单个设备支持的多个 XLA 设备:
$ set XLA_FLAGS="--xla_force_host_platform_device_count=8"
Run Code Online (Sandbox Code Playgroud)
在Python中,你可以这样做
$ set XLA_FLAGS="--xla_force_host_platform_device_count=8"
Run Code Online (Sandbox Code Playgroud)
请注意,当仅存在一个物理设备时,此处的所有“设备”都将由同一线程池支持。这不会提高代码的性能,但对于在单设备机器上测试并行实现的语义很有用。