使用提供的输出数组进行 NumPy 花式索引

Vad*_*rov 5 python numpy

我有一个二维数组并用一对数组索引它(实际上我的数组要大得多,有数百万个元素):

a = np.array([[1, 2, 3], [4, 5, 6]])
b = a[[0, 0, 0, 1], [0, 1, 2, 0]]
Run Code Online (Sandbox Code Playgroud)

索引将分配一个新数组。有没有办法用提供的输出数组来做这个索引?

我看着np.takeand np.choose,但似乎它们不适用于一对数组。np.take(..., out=buf)如果我拆散数组并手动构造一维实例,我设法使用,但它会导致更多的内存访问,并且几乎扼杀了消除索引结果分配的改进。

Pie*_*e D 1

很久以前就有人问过这个问题,但以防万一:

如果您的大数组(a在您的问题中)是 C 连续的,那么您可以使用np.lib.stride_tricks.as_strided()而不是np.ravel()问题中的 as 。它返回大数组的视图,无需复制,并且速度非常快。np.take()然后您可以与现有目的地一起使用out=...

但请注意,在您的问题中,索引对本身很大(元素数量是您寻求的输出的两倍)。如果我们将其转换为一维索引(如下np.ravel_multi_index(ix, a.shape)所示),这也是一个与您想要的输出大小相同的新数组,所以最终您可能不会节省太多时间而不是内存。

无论如何,这是以下的一种用法np.lib.stride_tricks.as_strided

n = np.prod(a.shape)
as1d_view = np.lib.stride_tricks.as_strided(a, (n, ), writeable=False)

# then, to index and copy into buf:
np.take(as1d_view, np.ravel_multi_index(ix, a.shape), out=buf)
Run Code Online (Sandbox Code Playgroud)

示例 1:单个大数组,多个索引

# setup
a = np.array([[1, 2, 3], [4, 5, 6]])
indices = [
    (np.array([0, 0, 0, 1]), np.array([0, 1, 2, 0])),
    (np.array([0, 1, 0, 1, 0, 1]), np.array([2, 1, 1, 2, 0, 0])),
    (np.array([0, 0, 1]), np.array([0, 1, 2])),
]

# init
n = np.prod(a.shape)
as1d_view = np.lib.stride_tricks.as_strided(a, (n, ), writeable=False)
m = max(len(ix[0]) for ix in indices)
buf = np.empty(m, dtype=a.dtype)

# loop
for i, ix in enumerate(indices):
    m = len(ix[0])
    np.take(as1d_view, np.ravel_multi_index(ix, a.shape), out=buf[:m])
    print(f'for index {i}, buf={buf[:m]!r}')

# gives:
# for index 0, buf=array([1, 2, 3, 4])
# for index 1, buf=array([3, 5, 2, 6, 1, 4])
# for index 2, buf=array([1, 2, 6])
Run Code Online (Sandbox Code Playgroud)

示例2:单个索引,多个大数组

# setup (presumably the large arrays are obtained one at a time instead...)
a_list = [
    np.array([[1, 2, 3], [4, 5, 6]]),
    np.arange(15).reshape(3, 5),
    np.random.randint(0, 20, (4,4)),
]
ix = np.array([0, 0, 0, 1]), np.array([0, 1, 2, 0])

# init
m = len(ix[0])
buf = np.empty(m, dtype=a.dtype)

# loop
for i, a in enumerate(a_list):
    ix1d = np.ravel_multi_index(ix, a.shape)
    as1d_view = np.lib.stride_tricks.as_strided(a, (n, ), writeable=False)
    np.take(as1d_view, ix1d, out=buf)
    print(f'for array {i}, buf={buf!r}')

# gives:
# for array 0, buf=array([1, 2, 3, 4])
# for array 1, buf=array([0, 1, 2, 5])
# for array 2, buf=array([ 5, 19,  1,  3])
Run Code Online (Sandbox Code Playgroud)