我正在将一些代码从纯 Python 重写为 JAX。我已经达到了这样的程度:在我的旧代码中,我使用 Python 的多处理模块来并行计算单个节点中所有 CPU 核心上的函数,如下所示:
# start pool process
pool = multiprocessing.Pool(processes=10) # if node has 10 CPU cores, start 10 processes
# use pool.map to evaluate function(input) for each input in parallel
# suppose len(inputs) is very large and 10 inputs are processed in parallel at a time
# store the results in a list called out
out = pool.map(function,inputs)
# close pool processes to free memory
pool.close()
pool.join()
Run Code Online (Sandbox Code Playgroud)
我知道 JAX 有 vmap 和 pmap,但我不明白它们是否可以直接替代我上面使用的 multiprocessing.pool.map。 …
parallel-processing multiprocessing spmd python-multiprocessing jax