Numpy:索引3D数组,其中最后一个轴的索引存储在2D数组中

Ker*_*ten 8 python arrays numpy python-2.7

我有一个ndarrayshape(z,y,x)含值.我想这个指数阵列的另一个ndarrayshape(y,x)包含我感兴趣的是值的的z-index.

import numpy as np
val_arr = np.arange(27).reshape(3,3,3)
z_indices = np.array([[1,0,2],
                      [0,0,1],
                      [2,0,1]])
Run Code Online (Sandbox Code Playgroud)

由于我的数组相当大,我试图使用它np.take来避免不必要的数组副本,但只是不能用我的头围绕索引三维数组.

我怎么有索引val_arrz_indices在所需的Z轴位置,以获得价值?预期结果将是:

result_arr = np.array([[9,1,20],
                       [3,4,14],
                       [24,7,17]])
Run Code Online (Sandbox Code Playgroud)

Ale*_*ley 7

您可以choose用来进行选择:

>>> z_indices.choose(val_arr)
array([[ 9,  1, 20],
       [ 3,  4, 14],
       [24,  7, 17]])
Run Code Online (Sandbox Code Playgroud)

该功能choose非常有用,但理解起来可能有些棘手.基本上,给定一个array(val_arr),我们可以z_indices沿着第一个轴从每个n维切片中做出一系列选择().

另外:任何花哨的索引操作都将创建一个新数组,而不是原始数据的视图.这是不可能的指标val_arrz_indices没有创建一个全新的阵列.


Dan*_*enz 5

受此线程的启发,使用np.ogrid

y,x = np.ogrid[0:3, 0:3]
print [z_indices, y, x]
[array([[1, 0, 2],
        [0, 0, 1],
        [2, 0, 1]]),
 array([[0],
        [1],
        [2]]),
 array([[0, 1, 2]])]

print val_arr[z_indices, y, x]
[[ 9  1 20]
 [ 3  4 14]
 [24  7 17]]
Run Code Online (Sandbox Code Playgroud)

我不得不承认多维花哨的索引可能会很混乱和令人困惑:)


Div*_*kar 5

具有可读性,np.choose绝对看起来很棒。

如果性能至关重要,您可以计算线性指数,然后使用np.take或使用扁平版本,.ravel()并从中提取这些特定元素val_arr。实现看起来像这样 -

def linidx_take(val_arr,z_indices):

    # Get number of columns and rows in values array
     _,nC,nR = val_arr.shape

     # Get linear indices and thus extract elements with np.take
    idx = nC*nR*z_indices + nR*np.arange(nR)[:,None] + np.arange(nC)
    return np.take(val_arr,idx) # Or val_arr.ravel()[idx]
Run Code Online (Sandbox Code Playgroud)

运行时测试并验证结果 -

基于 Ogrid 的解决方案来自here这些测试的通用版本,如下所示:

In [182]: def ogrid_based(val_arr,z_indices):
     ...:   v_shp = val_arr.shape
     ...:   y,x = np.ogrid[0:v_shp[1], 0:v_shp[2]]
     ...:   return val_arr[z_indices, y, x]
     ...: 
Run Code Online (Sandbox Code Playgroud)

案例#1:较小的数据量

In [183]: val_arr = np.random.rand(30,30,30)
     ...: z_indices = np.random.randint(0,30,(30,30))
     ...: 

In [184]: np.allclose(z_indices.choose(val_arr),ogrid_based(val_arr,z_indices))
Out[184]: True

In [185]: np.allclose(z_indices.choose(val_arr),linidx_take(val_arr,z_indices))
Out[185]: True

In [187]: %timeit z_indices.choose(val_arr)
1000 loops, best of 3: 230 µs per loop

In [188]: %timeit ogrid_based(val_arr,z_indices)
10000 loops, best of 3: 54.1 µs per loop

In [189]: %timeit linidx_take(val_arr,z_indices)
10000 loops, best of 3: 30.3 µs per loop
Run Code Online (Sandbox Code Playgroud)

案例#2:更大的数据量

In [191]: val_arr = np.random.rand(300,300,300)
     ...: z_indices = np.random.randint(0,300,(300,300))
     ...: 

In [192]: z_indices.choose(val_arr) # Seems like there is some limitation here with bigger arrays.
Traceback (most recent call last):

  File "<ipython-input-192-10c3bb600361>", line 1, in <module>
    z_indices.choose(val_arr)

ValueError: Need between 2 and (32) array objects (inclusive).


In [194]: np.allclose(linidx_take(val_arr,z_indices),ogrid_based(val_arr,z_indices))
Out[194]: True

In [195]: %timeit ogrid_based(val_arr,z_indices)
100 loops, best of 3: 3.67 ms per loop

In [196]: %timeit linidx_take(val_arr,z_indices)
100 loops, best of 3: 2.04 ms per loop
Run Code Online (Sandbox Code Playgroud)


Rob*_*rto 5

如果你有 numpy >= 1.15.0 你可以使用numpy.take_along_axis. 在你的情况下:

result_array = numpy.take_along_axis(val_arr, z_indices.reshape((3,3,1)), axis=2)
Run Code Online (Sandbox Code Playgroud)

这应该会在一行简洁的代码中为您提供您想要的结果。请注意索引数组的大小。它需要与您的维度数量相同val_arr(并且前两个维度的大小相同)。