numpy多维索引和函数“ take”

Hay*_*row 5 python indexing numpy multidimensional-array

在一周的奇数天,我几乎可以理解numpy中的多维索引。Numpy具有函数“ take”,该函数似乎可以实现我想要的功能,但是额外的好处是,我可以控制如果索引超出范围时会发生什么情况。具体来说,我有一个3维数组作为查询表来询问

lut = np.ones([13,13,13],np.bool)
Run Code Online (Sandbox Code Playgroud)

和一个2x2的3个长向量数组,用作表的索引

arr = np.arange(12).reshape([2,2,3]) % 13 
Run Code Online (Sandbox Code Playgroud)

IIUC,如果我要写的lut[arr]话,arr则被视为2x2x3的数字数组,当这些数字用作索引时,lut它们各自返回13x13的数组。这解释了原因lut[arr].shape is (2, 2, 3, 13, 13)

我可以通过写来做自己想要的

lut[ arr[:,:,0],arr[:,:,1],arr[:,:,2] ] #(is there a better way to write this?)
Run Code Online (Sandbox Code Playgroud)

现在这三个术语的作用就好像它们已被压缩以生成2x2的元组数组并lut[<tuple>]从生成单个元素lut。最终结果是2x2的条目数组lut,正是我想要的。

我已经阅读了'take'功能的文档...

此功能与“奇特”索引(使用数组索引数组)具有相同的作用;但是,如果您需要沿给定轴的元素,则使用起来会更容易。

axis:int,可选
在其上选择值的轴。

也许天真地,我认为设置axis=2我将获得三个值以用作三元组来执行查找,但是实际上

np.take(lut,arr).shape =  (2, 2, 3)
np.take(lut,arr,axis=0).shape =  (2, 2, 3, 13, 13)
np.take(lut,arr,axis=1).shape =  (13, 2, 2, 3, 13)
np.take(lut,arr,axis=2).shape =  (13, 13, 2, 2, 3)
Run Code Online (Sandbox Code Playgroud)

所以很明显我不明白发生了什么。谁能告诉我如何实现我想要的?

Div*_*kar 6

我们可以计算线性指数,然后使用np.take-

np.take(lut, np.ravel_multi_index(arr.T, lut.shape)).T
Run Code Online (Sandbox Code Playgroud)

如果您对替代方案持开放态度,我们可以将索引数组重新整形为2D,转换为元组,用它索引到数据数组中,给我们1D,可以重新整形为2D-

lut[tuple(arr.reshape(-1,arr.shape[-1]).T)].reshape(arr.shape[:2])
Run Code Online (Sandbox Code Playgroud)

样品运行 -

In [49]: lut = np.random.randint(11,99,(13,13,13))

In [50]: arr = np.arange(12).reshape([2,2,3])

In [51]: lut[ arr[:,:,0],arr[:,:,1],arr[:,:,2] ] # Original approach
Out[51]: 
array([[41, 21],
       [94, 22]])

In [52]: np.take(lut, np.ravel_multi_index(arr.T, lut.shape)).T
Out[52]: 
array([[41, 21],
       [94, 22]])

In [53]: lut[tuple(arr.reshape(-1,arr.shape[-1]).T)].reshape(arr.shape[:2])
Out[53]: 
array([[41, 21],
       [94, 22]])
Run Code Online (Sandbox Code Playgroud)

我们可以避免这种np.take方法的双重转置,就像这样 -

In [55]: np.take(lut, np.ravel_multi_index(arr.transpose(2,0,1), lut.shape))
Out[55]: 
array([[41, 21],
       [94, 22]])
Run Code Online (Sandbox Code Playgroud)

推广到通用维度的多维数组

这可以推广到通用编号的 ndarrays。昏暗的,像这样 -

np.take(lut, np.ravel_multi_index(np.rollaxis(arr,-1,0), lut.shape))
Run Code Online (Sandbox Code Playgroud)

tuple-based方法应该无需任何更改即可工作。

这是相同的示例运行 -

In [95]: lut = np.random.randint(11,99,(13,13,13,13))

In [96]: arr = np.random.randint(0,13,(2,3,4,4))

In [97]: lut[ arr[:,:,:,0] , arr[:,:,:,1],arr[:,:,:,2],arr[:,:,:,3] ]
Out[97]: 
array([[[95, 11, 40, 75],
        [38, 82, 11, 38],
        [30, 53, 69, 21]],

       [[61, 74, 33, 94],
        [90, 35, 89, 72],
        [52, 64, 85, 22]]])

In [98]: np.take(lut, np.ravel_multi_index(np.rollaxis(arr,-1,0), lut.shape))
Out[98]: 
array([[[95, 11, 40, 75],
        [38, 82, 11, 38],
        [30, 53, 69, 21]],

       [[61, 74, 33, 94],
        [90, 35, 89, 72],
        [52, 64, 85, 22]]])
Run Code Online (Sandbox Code Playgroud)


Hay*_*row 0

最初的问题是尝试在表中进行查找,但某些索引超出范围,我想在发生这种情况时控制行为。

import numpy as np
lut = np.ones((5,7,11),np.int) # a 3-dimensional lookup table
print("lut.shape = ",lut.shape ) # (5,7,11)

# valid points are in the interior with value 99,
# invalid points are on the faces with value 0
lut[:,:,:] = 0
lut[1:-1,1:-1,1:-1] = 99

# set up an array of indexes with many of them too large or too small
start = -35
arr = np.arange(start,2*11*3+start,1).reshape(2,11,3)

# This solution has the advantage that I can understand what is going on
# and so I can amend it if I need to

# split arr into tuples along axis=2
arrchannels = arr[:,:,0],arr[:,:,1],arr[:,:,2]

# convert into a flat array but clip the values
ravelledarr = np.ravel_multi_index(arrchannels, lut.shape, mode='clip')

# and now turn back into a list of numpy arrays
# (not an array of the original shape )
clippedarr = np.unravel_index( ravelledarr, lut.shape)
print(clippedarr[0].shape,"*",len(clippedarr)) # produces (2, 11) * 3

# and now I can do the lookup with the indexes clipped to fit
print(lut[clippedarr])

# these are more succinct but opaque ways of doing the same
# due to @Divakar and @hjpauli respectively
print( np.take(lut, np.ravel_multi_index(arr.T, lut.shape, mode='clip')).T )
print( lut.flat[np.ravel_multi_index(arr.T, lut.shape, mode='clip')].T )
Run Code Online (Sandbox Code Playgroud)

实际的应用是,我有一个 RGB 图像,其中包含一些带有一些标记的纹理木材,并且我已经识别了其中的一块。我想获取该补丁中的一组像素,并标记整个图像中与其中一个匹配的所有点。256x256x256 存在表太大,因此我对补丁中的像素运行聚类算法,并为每个簇设置存在表(补丁中的颜色通过 rgb 或 hsv 空间形成细长的线程,因此簇周围的框很小)。

我将存在表设置为比需要的稍大,并用 False 填充每个面。

一旦我设置了这些小的存在表,我现在可以通过查找表中的每个像素来测试图像的其余部分是否匹配补丁,并使用剪切来使通常不会映射到表中的像素实际映射到桌子的一个面(并获取值“False”)