我正在使用 JAX,并且我想执行类似的操作
@jax.jit
def fun(x, index):
x[:index] = other_fun(x[:index])
return x
Run Code Online (Sandbox Code Playgroud)
这不能在 下执行jit。有没有办法用jax.opsor来做到这一点jax.lax?我想过使用,但我找不到一种不会再次遇到同样问题jax.ops.index_update(x, idx, y)的计算方法。y
如果您的索引是静态的,@rvinas之前的答案using效果dynamic_slice很好,但您也可以使用动态索引来完成此操作jnp.where。例如:
import jax
import jax.numpy as jnp
def other_fun(x):
return x + 1
@jax.jit
def fun(x, index):
mask = jnp.arange(x.shape[0]) < index
return jnp.where(mask, other_fun(x), x)
x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3568 次 |
| 最近记录: |