JAX vmap vs pmap vs Python 多处理

Jim*_*nor 2 parallel-processing multiprocessing spmd python-multiprocessing jax

我正在将一些代码从纯 Python 重写为 JAX。我已经达到了这样的程度:在我的旧代码中,我使用 Python 的多处理模块来并行计算单个节点中所有 CPU 核心上的函数,如下所示:

# start pool process 
pool = multiprocessing.Pool(processes=10) # if node has 10 CPU cores, start 10 processes

# use pool.map to evaluate function(input) for each input in parallel
# suppose len(inputs) is very large and 10 inputs are processed in parallel at a time
# store the results in a list called out
out = pool.map(function,inputs)

# close pool processes to free memory
pool.close()
pool.join()
Run Code Online (Sandbox Code Playgroud)

我知道 JAX 有 vmap 和 pmap,但我不明白它们是否可以直接替代我上面使用的 multiprocessing.pool.map。

  1. vmap(function,in_axes=0)(inputs)分配给所有可用的 CPU 核心还是什么?
  2. pmap(function,in_axes=0)(inputs)与 vmap 和 multiprocessing.pool.map有什么不同?
  3. 我上面对 multiprocessing.pool.map 的使用是 pmap 的“单程序、多数据 (SPMD)”代码示例吗?
  4. 当我真正这样做时,pmap(function,in_axes=0)(inputs)我收到一个错误 -- ValueError: 编译需要 10 个逻辑设备的计算,但只有 1 个 XLA 设备可用 (num_replicas=10,num_partitions=1) -- 这是什么意思?
  5. 最后,我的用例非常简单:我只想使用单个节点上的部分/全部 CPU 核心(例如,我的 Macbook 上的所有 10 个 CPU 核心)。但我听说过嵌套 pmap(vmap) ——这是否用于在多个连接节点的核心上并行化(例如在超级计算机上)?这更类似于 mpi4py 而不是多处理(后者仅限于单个节点)。

jak*_*vdp 5

  1. vmap(function,in_axes=0)(inputs)分配给所有可用的 CPU 核心还是什么?

不,vmap与并行化无关。它是矢量化变换,而不是并行化变换。在正常操作过程中,JAX 可能通过 XLA 使用多个核心,因此 vmapped 操作也可能会这样做。但 . 中没有明确的并行化vmap

  1. 和 和有什么pmap(function,in_axes=0)(inputs)不同?vmapmultiprocessing.pool.map

pmap在多个 XLA 设备上并行。vmap不是并行化,而是在单个设备上进行矢量化。multiprocessing在多个 Python 进程上并行。

  1. 我上面对 multiprocessing.pool.map 的使用是 pmap 的“单程序、多数据 (SPMD)”代码示例吗?

是的,它可以被描述为跨多个 python 进程的 SPMD。

  1. 当我真正这样做时,pmap(function,in_axes=0)(inputs)我收到一个错误 -- ValueError: compiling computation that requires 10 logical devices, but only 1 XLA devices are available (num_replicas=10, num_partitions=1)-- 这是什么意思?

pmap在多个 XLA 设备上并行,并且您仅配置了一个 XLA 设备,因此无法执行请求的操作。

  1. 最后,我的用例非常简单:我只想使用单个节点上的部分/全部 CPU 核心(例如,我的 Macbook 上的所有 10 个 CPU 核心)。但我听说过嵌套 pmap(vmap) ——这是否用于在多个连接节点的核心上并行化(例如在超级计算机上)?这更类似于 mpi4py 而不是多处理(后者仅限于单个节点)。

是的,我相信它pmap可以用于在多个 CPU 核心上进行计算。是否嵌套vmap无关紧要。请参阅具有多核 CPU 的 JAX pmap

还要注意的是,它jax.pmap已被弃用,取而代之的是较新的jax.shard_map,它对于多设备/多主机计算来说是一种更加灵活的转换。这里有一些信息:https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.htmlhttps://jax.readthedocs.io/en/latest/jep/14273-shard-map.html