mut*_*oid 3 python jit python-3.x jax
当 JIT 函数的输入结构基本保持不变(除了一个轴具有不同数量的元素之外)时,是否可以避免重新编译 JIT 函数?
import jax
@jax.jit
def f(x):
print('recompiling')
return (x + 10) * 100
a = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling
b = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready())
c = f(jax.numpy.arange(450000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling. It would be nice if it weren't
Run Code Online (Sandbox Code Playgroud)
要求:pip install jax、jaxlib
不,当您调用具有不同形状的数组的函数时,无法避免重新编译。从根本上来说,JAX 为静态形状的输入和输出编译函数,并且使用新形状的数组调用 JIT 编译的函数将始终触发重新编译。
目前正在进行一些放宽此要求的工作(在 JAX 的 github 存储库中搜索“动态形状”),但目前没有此类 API 可用。
| 归档时间: |
|
| 查看次数: |
1069 次 |
| 最近记录: |