使用多CPU核心的正确方法是什么jax.pmap?
以下示例在 CPU 核心后端上为 SPMD 创建环境变量,测试 JAX 是否识别设备,并尝试设备锁定。
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'
import jax as jx
import jax.numpy as jnp
jx.local_device_count()
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# 2
jx.devices("cpu")
# [CpuDevice(id=0), CpuDevice(id=1)]
def sfunc(x): while True: pass
jx.pmap(sfunc)(jnp.arange(2))
Run Code Online (Sandbox Code Playgroud)
从jupyter内核执行并观察htop发现只有一个核心被锁定
htop当省略前两行并运行时,我收到相同的输出:
$ env XLA_FLAGS=--xla_force_host_platform_device_count=2 python test.py
Run Code Online (Sandbox Code Playgroud)
替换sfunc为
def sfunc(x): return 2.0*x
Run Code Online (Sandbox Code Playgroud)
并打电话
jx.pmap(sfunc)(jnp.arange(2))
# ShardedDeviceArray([0., 2.], dtype=float32, weak_type=True)
Run Code Online (Sandbox Code Playgroud)
确实返回一个SharedDeviecArray.
显然我没有正确配置 JAX/XLA …