小编Car*_*ear的帖子

更新 JAX 中二维数组的整行或整列

我是 JAX 新手,编写 JIT 编译的代码对我来说相当困难。我正在努力实现以下目标:

给定 JAX 中的一个(n,n)数组mat,我想将一个(1,n)或一个(n,1)数组分别添加到原始数组的任意行或列mat

如果我想添加一个行数组r到第三行,则 numpy 的等价物是:

# if mat is a numpy array
mat[2,:] = mat[2,:] + r

Run Code Online (Sandbox Code Playgroud)

我知道如何更新 JAX 中数组元素的唯一方法是使用array.at[i].set(). 我不确定如何使用它来更新行或列而不显式使用 for 循环。

python matrix multidimensional-array jax

4
推荐指数
1
解决办法
1369
查看次数

标签 统计

jax ×1

matrix ×1

multidimensional-array ×1

python ×1