小编Dav*_*idJ的帖子

具有多核 CPU 的 JAX pmap

使用多CPU核心的正确方法是什么jax.pmap

以下示例在 CPU 核心后端上为 SPMD 创建环境变量,测试 JAX 是否识别设备,并尝试设备锁定。

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

import jax as jx
import jax.numpy as jnp

jx.local_device_count()
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# 2

jx.devices("cpu")
# [CpuDevice(id=0), CpuDevice(id=1)]

def sfunc(x): while True: pass

jx.pmap(sfunc)(jnp.arange(2))
Run Code Online (Sandbox Code Playgroud)

从jupyter内核执行并观察htop发现只有一个核心被锁定

从 jupyter 内核执行

htop当省略前两行并运行时,我收到相同的输出:

$ env XLA_FLAGS=--xla_force_host_platform_device_count=2 python test.py
Run Code Online (Sandbox Code Playgroud)

替换sfunc

def sfunc(x): return 2.0*x
Run Code Online (Sandbox Code Playgroud)

并打电话

jx.pmap(sfunc)(jnp.arange(2))
# ShardedDeviceArray([0., 2.], dtype=float32, weak_type=True)
Run Code Online (Sandbox Code Playgroud)

确实返回一个SharedDeviecArray.

显然我没有正确配置 JAX/XLA …

multicore pmap jax

8
推荐指数
1
解决办法
2454
查看次数

JAX(XLA) 与 Numba(LLVM) 减少

在计算时间方面,JAX 是否有可能仅减少 CPU 与 Numba 相当?

\n

编译器直接来自conda

\n
$ conda install -c conda-forge numba jax\n
Run Code Online (Sandbox Code Playgroud)\n

这是一个一维 NumPy 数组示例

\n
import numpy as np\nimport numba as nb\nimport jax as jx\n\n@nb.njit\ndef reduce_1d_njit_serial(x):\n    s = 0\n    for xi in x:\n        s += xi\n    return s\n\n@jx.jit\ndef reduce_1d_jax_serial(x):\n    s = 0\n    for xi in x:\n        s += xi\n    return s\n\nN = 2**10\na = np.random.randn(N)\n
Run Code Online (Sandbox Code Playgroud)\n

用于timeit以下

\n
    \n
  1. np.add.reduce(a)给出1.99 \xc2\xb5s ...
  2. \n
  3. reduce_1d_njit_serial(a)给出1.43 \xc2\xb5s ... …

python reduce jit numba jax

4
推荐指数
1
解决办法
6249
查看次数

标签 统计

jax ×2

jit ×1

multicore ×1

numba ×1

pmap ×1

python ×1

reduce ×1