Ker*_*ten 8 python arrays numpy python-2.7
我有一个ndarray
的shape(z,y,x)
含值.我想这个指数阵列的另一个ndarray
的shape(y,x)
包含我感兴趣的是值的的z-index.
import numpy as np
val_arr = np.arange(27).reshape(3,3,3)
z_indices = np.array([[1,0,2],
[0,0,1],
[2,0,1]])
Run Code Online (Sandbox Code Playgroud)
由于我的数组相当大,我试图使用它np.take
来避免不必要的数组副本,但只是不能用我的头围绕索引三维数组.
我怎么有索引val_arr
与z_indices
在所需的Z轴位置,以获得价值?预期结果将是:
result_arr = np.array([[9,1,20],
[3,4,14],
[24,7,17]])
Run Code Online (Sandbox Code Playgroud)
您可以choose
用来进行选择:
>>> z_indices.choose(val_arr)
array([[ 9, 1, 20],
[ 3, 4, 14],
[24, 7, 17]])
Run Code Online (Sandbox Code Playgroud)
该功能choose
非常有用,但理解起来可能有些棘手.基本上,给定一个array(val_arr
),我们可以z_indices
沿着第一个轴从每个n维切片中做出一系列选择().
另外:任何花哨的索引操作都将创建一个新数组,而不是原始数据的视图.这是不可能的指标val_arr
有z_indices
没有创建一个全新的阵列.
y,x = np.ogrid[0:3, 0:3]
print [z_indices, y, x]
[array([[1, 0, 2],
[0, 0, 1],
[2, 0, 1]]),
array([[0],
[1],
[2]]),
array([[0, 1, 2]])]
print val_arr[z_indices, y, x]
[[ 9 1 20]
[ 3 4 14]
[24 7 17]]
Run Code Online (Sandbox Code Playgroud)
我不得不承认多维花哨的索引可能会很混乱和令人困惑:)
具有可读性,np.choose
绝对看起来很棒。
如果性能至关重要,您可以计算线性指数,然后使用np.take
或使用扁平版本,.ravel()
并从中提取这些特定元素val_arr
。实现看起来像这样 -
def linidx_take(val_arr,z_indices):
# Get number of columns and rows in values array
_,nC,nR = val_arr.shape
# Get linear indices and thus extract elements with np.take
idx = nC*nR*z_indices + nR*np.arange(nR)[:,None] + np.arange(nC)
return np.take(val_arr,idx) # Or val_arr.ravel()[idx]
Run Code Online (Sandbox Code Playgroud)
运行时测试并验证结果 -
基于 Ogrid 的解决方案来自here
这些测试的通用版本,如下所示:
In [182]: def ogrid_based(val_arr,z_indices):
...: v_shp = val_arr.shape
...: y,x = np.ogrid[0:v_shp[1], 0:v_shp[2]]
...: return val_arr[z_indices, y, x]
...:
Run Code Online (Sandbox Code Playgroud)
案例#1:较小的数据量
In [183]: val_arr = np.random.rand(30,30,30)
...: z_indices = np.random.randint(0,30,(30,30))
...:
In [184]: np.allclose(z_indices.choose(val_arr),ogrid_based(val_arr,z_indices))
Out[184]: True
In [185]: np.allclose(z_indices.choose(val_arr),linidx_take(val_arr,z_indices))
Out[185]: True
In [187]: %timeit z_indices.choose(val_arr)
1000 loops, best of 3: 230 µs per loop
In [188]: %timeit ogrid_based(val_arr,z_indices)
10000 loops, best of 3: 54.1 µs per loop
In [189]: %timeit linidx_take(val_arr,z_indices)
10000 loops, best of 3: 30.3 µs per loop
Run Code Online (Sandbox Code Playgroud)
案例#2:更大的数据量
In [191]: val_arr = np.random.rand(300,300,300)
...: z_indices = np.random.randint(0,300,(300,300))
...:
In [192]: z_indices.choose(val_arr) # Seems like there is some limitation here with bigger arrays.
Traceback (most recent call last):
File "<ipython-input-192-10c3bb600361>", line 1, in <module>
z_indices.choose(val_arr)
ValueError: Need between 2 and (32) array objects (inclusive).
In [194]: np.allclose(linidx_take(val_arr,z_indices),ogrid_based(val_arr,z_indices))
Out[194]: True
In [195]: %timeit ogrid_based(val_arr,z_indices)
100 loops, best of 3: 3.67 ms per loop
In [196]: %timeit linidx_take(val_arr,z_indices)
100 loops, best of 3: 2.04 ms per loop
Run Code Online (Sandbox Code Playgroud)
如果你有 numpy >= 1.15.0 你可以使用numpy.take_along_axis
. 在你的情况下:
result_array = numpy.take_along_axis(val_arr, z_indices.reshape((3,3,1)), axis=2)
Run Code Online (Sandbox Code Playgroud)
这应该会在一行简洁的代码中为您提供您想要的结果。请注意索引数组的大小。它需要与您的维度数量相同val_arr
(并且前两个维度的大小相同)。