Dav*_*idJ 4 python reduce jit numba jax
在计算时间方面,JAX 是否有可能仅减少 CPU 与 Numba 相当?
\n编译器直接来自conda:
$ conda install -c conda-forge numba jax\nRun Code Online (Sandbox Code Playgroud)\n这是一个一维 NumPy 数组示例
\nimport numpy as np\nimport numba as nb\nimport jax as jx\n\n@nb.njit\ndef reduce_1d_njit_serial(x):\n s = 0\n for xi in x:\n s += xi\n return s\n\n@jx.jit\ndef reduce_1d_jax_serial(x):\n s = 0\n for xi in x:\n s += xi\n return s\n\nN = 2**10\na = np.random.randn(N)\nRun Code Online (Sandbox Code Playgroud)\n用于timeit以下
np.add.reduce(a)给出1.99 \xc2\xb5s ...reduce_1d_njit_serial(a)给出1.43 \xc2\xb5s ...reduce_1d_jax_serial(a).item()给出23.5 \xc2\xb5s ...请注意,jx.numpy.sum(a)和 usingjx.lax.fori_loop给出了可比(稍微慢一些)的比较。次 到reduce_1d_jax_serial.
似乎有更好的方法来制作 XLA 的还原。
\n编辑:编译时间不包括在内,因为打印语句继续检查结果。
\n当使用 JAX 执行这些类型的微基准测试时,您必须小心确保您正在测量您认为正在测量的内容。JAX 基准测试常见问题解答中有一些提示。通过实施其中一些最佳实践,我发现以下内容适合您的基准:
\nimport jax.numpy as jnp\n\n# Native jit-compiled XLA sum\njit_sum = jx.jit(jnp.sum)\n\n# Avoid including device transfer cost in the benchmarks\na_jax = jnp.array(a)\n\n# Prevent measuring compilation time\n_ = reduce_1d_njit_serial(a)\n_ = reduce_1d_jax_serial(a_jax)\n_ = jit_sum(a_jax)\n\n%timeit np.add.reduce(a)\n# 100000 loops, best of 5: 2.33 \xc2\xb5s per loop\n\n%timeit reduce_1d_njit_serial(a)\n# 1000000 loops, best of 5: 1.43 \xc2\xb5s per loop\n\n%timeit reduce_1d_jax_serial(a_jax).block_until_ready()\n# 100000 loops, best of 5: 6.24 \xc2\xb5s per loop\n\n%timeit jit_sum(a_jax).block_until_ready()\n# 100000 loops, best of 5: 4.37 \xc2\xb5s per loop\nRun Code Online (Sandbox Code Playgroud)\n您将看到,对于这些微基准测试,JAX 比 numpy 和 numba 慢几毫秒。那么这是否意味着 JAX 很慢?是和否;您将在 JAX 常见问题解答中找到该问题的更完整答案:JAX 比 numpy 更快吗?。简而言之,这种计算量非常小,差异主要取决于 Python 调度时间,而不是在数组上操作所花费的时间。JAX 项目并没有投入太多精力来优化微基准的 Python 调度:在实践中这并不是那么重要,因为 JAX 中每个程序都会产生一次成本,而不是 numpy 中每个操作一次。
\n| 归档时间: |
|
| 查看次数: |
6249 次 |
| 最近记录: |