JAX(XLA) 与 Numba(LLVM) 减少

Dav*_*idJ 4 python reduce jit numba jax

在计算时间方面,JAX 是否有可能仅减少 CPU 与 Numba 相当?

\n

编译器直接来自conda

\n
$ conda install -c conda-forge numba jax\n
Run Code Online (Sandbox Code Playgroud)\n

这是一个一维 NumPy 数组示例

\n
import numpy as np\nimport numba as nb\nimport jax as jx\n\n@nb.njit\ndef reduce_1d_njit_serial(x):\n    s = 0\n    for xi in x:\n        s += xi\n    return s\n\n@jx.jit\ndef reduce_1d_jax_serial(x):\n    s = 0\n    for xi in x:\n        s += xi\n    return s\n\nN = 2**10\na = np.random.randn(N)\n
Run Code Online (Sandbox Code Playgroud)\n

用于timeit以下

\n
    \n
  1. np.add.reduce(a)给出1.99 \xc2\xb5s ...
  2. \n
  3. reduce_1d_njit_serial(a)给出1.43 \xc2\xb5s ...
  4. \n
  5. reduce_1d_jax_serial(a).item()给出23.5 \xc2\xb5s ...
  6. \n
\n

请注意,jx.numpy.sum(a)和 usingjx.lax.fori_loop给出了可比(稍微慢一些)的比较。次 到reduce_1d_jax_serial.

\n

似乎有更好的方法来制作 XLA 的还原。

\n

编辑:编译时间不包括在内,因为打印语句继续检查结果。

\n

jak*_*vdp 6

当使用 JAX 执行这些类型的微基准测试时,您必须小心确保您正在测量您认为正在测量的内容。JAX 基准测试常见问题解答中有一些提示。通过实施其中一些最佳实践,我发现以下内容适合您的基准:

\n
import jax.numpy as jnp\n\n# Native jit-compiled XLA sum\njit_sum = jx.jit(jnp.sum)\n\n# Avoid including device transfer cost in the benchmarks\na_jax = jnp.array(a)\n\n# Prevent measuring compilation time\n_ = reduce_1d_njit_serial(a)\n_ = reduce_1d_jax_serial(a_jax)\n_ = jit_sum(a_jax)\n\n%timeit np.add.reduce(a)\n# 100000 loops, best of 5: 2.33 \xc2\xb5s per loop\n\n%timeit reduce_1d_njit_serial(a)\n# 1000000 loops, best of 5: 1.43 \xc2\xb5s per loop\n\n%timeit reduce_1d_jax_serial(a_jax).block_until_ready()\n# 100000 loops, best of 5: 6.24 \xc2\xb5s per loop\n\n%timeit jit_sum(a_jax).block_until_ready()\n# 100000 loops, best of 5: 4.37 \xc2\xb5s per loop\n
Run Code Online (Sandbox Code Playgroud)\n

您将看到,对于这些微基准测试,JAX 比 numpy 和 numba 慢几毫秒。那么这是否意味着 JAX 很慢?是和否;您将在 JAX 常见问题解答中找到该问题的更完整答案:JAX 比 numpy 更快吗?。简而言之,这种计算量非常小,差异主要取决于 Python 调度时间,而不是在数组上操作所花费的时间。JAX 项目并没有投入太多精力来优化微基准的 Python 调度:在实践中这并不是那么重要,因为 JAX 中每个程序都会产生一次成本,而不是 numpy 中每个操作一次。

\n