我是 JAX 新手,编写 JIT 编译的代码对我来说相当困难。我正在努力实现以下目标:
给定 JAX 中的一个(n,n)数组mat,我想将一个(1,n)或一个(n,1)数组分别添加到原始数组的任意行或列mat。
如果我想添加一个行数组r到第三行,则 numpy 的等价物是:
# if mat is a numpy array
mat[2,:] = mat[2,:] + r
Run Code Online (Sandbox Code Playgroud)
我知道如何更新 JAX 中数组元素的唯一方法是使用array.at[i].set(). 我不确定如何使用它来更新行或列而不显式使用 for 循环。