如何沿批量维度广播 numpy 索引?

Sco*_*ott 5 python numpy multidimensional-array matrix-indexing array-broadcasting

例如,np.array([[1,2],[3,4]])[np.triu_indices(2)]具有 shape (3,),是上三角条目的扁平列表。但是,如果我有一批 2x2 矩阵:

foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)
Run Code Online (Sandbox Code Playgroud)

我想获得每个矩阵的上三角索引,最简单的尝试是:

foo[:,np.triu_indices(2)]
Run Code Online (Sandbox Code Playgroud)

然而,这个对象实际上是有形状的(与我们批量提取上三角条目时所期望的相反)(30,2,3,2)(30,3)

我们如何沿着批量维度广播元组索引?

Div*_*kar 4

获取元组并使用它们来索引最后两个暗淡 -

r,c = np.triu_indices(2)
out = foo[:,r,c]
Run Code Online (Sandbox Code Playgroud)

或者,单行代码Ellipsis适用于3D2D数组 -

foo[(Ellipsis,)+np.triu_indices(2)]
Run Code Online (Sandbox Code Playgroud)

它同样适用于2D数组 -

out = foo[r,c] # foo as 2D input array
Run Code Online (Sandbox Code Playgroud)

遮蔽方式

3D阵列案例

我们还可以使用掩码作为masking基础方式 -

foo[:,~np.tri(2,k=-1, dtype=bool)]
Run Code Online (Sandbox Code Playgroud)

二维数组案例

foo[~np.tri(2,k=-1, dtype=bool)]
Run Code Online (Sandbox Code Playgroud)

  • 雅啊,丹克!!为了让它不那么神秘,我们还可以写:`foo[(..., *np.triu_indices(2))]` (2认同)