Car*_*ear 4 python matrix multidimensional-array jax
我是 JAX 新手,编写 JIT 编译的代码对我来说相当困难。我正在努力实现以下目标:
给定 JAX 中的一个(n,n)数组mat,我想将一个(1,n)或一个(n,1)数组分别添加到原始数组的任意行或列mat。
如果我想添加一个行数组r到第三行,则 numpy 的等价物是:
# if mat is a numpy array
mat[2,:] = mat[2,:] + r
Run Code Online (Sandbox Code Playgroud)
我知道如何更新 JAX 中数组元素的唯一方法是使用array.at[i].set(). 我不确定如何使用它来更新行或列而不显式使用 for 循环。
JAX 数组是不可变的,因此您无法对数组条目进行就地修改。但您可以使用np.ndarray.at语法实现类似的结果。例如,相当于
mat[2,:] = mat[2,:] + r
Run Code Online (Sandbox Code Playgroud)
将会
mat = mat.at[2,:].set(mat[2,:] + r)
Run Code Online (Sandbox Code Playgroud)
但在这种情况下,您可以使用该add方法来提高效率:
mat = mat.at[2:].add(r)
Run Code Online (Sandbox Code Playgroud)
以下是向二维数组添加行和列数组的示例:
import jax.numpy as jnp
mat = jnp.zeros((5, 5))
# Create 2D row & col arrays, as in question
row = jnp.ones(5).reshape(1, 5)
col = jnp.ones(5).reshape(5, 1)
mat = mat.at[1:2, :].add(row)
mat = mat.at[:, 2:3].add(col)
print(mat)
# [[0. 0. 1. 0. 0.]
# [1. 1. 2. 1. 1.]
# [0. 0. 1. 0. 0.]
# [0. 0. 1. 0. 0.]
# [0. 0. 1. 0. 0.]]
Run Code Online (Sandbox Code Playgroud)
有关此问题的更多讨论,请参阅JAX Sharp Bits:就地更新。