JAX 仅在 jit 下的数组切片上应用函数

Fed*_*hin 7 python numpy jax

我正在使用 JAX,并且我想执行类似的操作

@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x
Run Code Online (Sandbox Code Playgroud)

这不能在 下执行jit。有没有办法用jax.opsor来做到这一点jax.lax?我想过使用,但我找不到一种不会再次遇到同样问题jax.ops.index_update(x, idx, y)的计算方法。y

jak*_*vdp 6

如果您的索引是静态的,@rvinas之前的答案using效果dynamic_slice很好,但您也可以使用动态索引来完成此操作jnp.where。例如:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
  mask = jnp.arange(x.shape[0]) < index
  return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
Run Code Online (Sandbox Code Playgroud)