更新 JAX 中二维数组的整行或整列

Car*_*ear 4 python matrix multidimensional-array jax

我是 JAX 新手,编写 JIT 编译的代码对我来说相当困难。我正在努力实现以下目标:

给定 JAX 中的一个(n,n)数组mat,我想将一个(1,n)或一个(n,1)数组分别添加到原始数组的任意行或列mat

如果我想添加一个行数组r到第三行,则 numpy 的等价物是:

# if mat is a numpy array
mat[2,:] = mat[2,:] + r

Run Code Online (Sandbox Code Playgroud)

我知道如何更新 JAX 中数组元素的唯一方法是使用array.at[i].set(). 我不确定如何使用它来更新行或列而不显式使用 for 循环。

jak*_*vdp 6

JAX 数组是不可变的,因此您无法对数组条目进行就地修改。但您可以使用np.ndarray.at语法实现类似的结果。例如,相当于

mat[2,:] = mat[2,:] + r
Run Code Online (Sandbox Code Playgroud)

将会

mat = mat.at[2,:].set(mat[2,:] + r)
Run Code Online (Sandbox Code Playgroud)

但在这种情况下,您可以使用该add方法来提高效率:

mat = mat.at[2:].add(r)
Run Code Online (Sandbox Code Playgroud)

以下是向二维数组添加行和列数组的示例:

import jax.numpy as jnp

mat = jnp.zeros((5, 5))

# Create 2D row & col arrays, as in question
row = jnp.ones(5).reshape(1, 5)
col = jnp.ones(5).reshape(5, 1)

mat = mat.at[1:2, :].add(row)
mat = mat.at[:, 2:3].add(col)

print(mat)
# [[0. 0. 1. 0. 0.]
#  [1. 1. 2. 1. 1.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 1. 0. 0.]]
Run Code Online (Sandbox Code Playgroud)

有关此问题的更多讨论,请参阅JAX Sharp Bits:就地更新