小编Jim*_*nor的帖子

JAX vmap vs pmap vs Python 多处理

我正在将一些代码从纯 Python 重写为 JAX。我已经达到了这样的程度:在我的旧代码中,我使用 Python 的多处理模块来并行计算单个节点中所有 CPU 核心上的函数,如下所示:

# start pool process 
pool = multiprocessing.Pool(processes=10) # if node has 10 CPU cores, start 10 processes

# use pool.map to evaluate function(input) for each input in parallel
# suppose len(inputs) is very large and 10 inputs are processed in parallel at a time
# store the results in a list called out
out = pool.map(function,inputs)

# close pool processes to free memory
pool.close()
pool.join()
Run Code Online (Sandbox Code Playgroud)

我知道 JAX 有 vmap 和 pmap,但我不明白它们是否可以直接替代我上面使用的 multiprocessing.pool.map。 …

parallel-processing multiprocessing spmd python-multiprocessing jax

2
推荐指数
1
解决办法
1299
查看次数