使用 NumPy 的 sum 时避免循环

JCO*_*idl 5 python arrays numpy

我经常需要对较大的 NumPy 数组的某些行或列进行求和。例如,拿这个数组:

>>> c = np.arange(18).reshape(3, 6)
>>> print(c)
[[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]]
Run Code Online (Sandbox Code Playgroud)

假设我只想在行索引为 0 或 2 的地方求和,并且列索引为 0、2、4 或 5。换句话说,我想要子数组的总和

[[ 0  2  4  5]
 [12 14 16 17]]
Run Code Online (Sandbox Code Playgroud)

我通常使用 NumPy 非常有用的ix_方法来做到这一点;例如

>>> np.sum(c[np.ix_([0,2],[0,2,4,5])])
70
Run Code Online (Sandbox Code Playgroud)

到现在为止还挺好。但是,现在假设我有一个不同的数组,e, ,c, ,但有两个前导维度。所以它的形状是 (2,3,3,6) 而不是 (3,6):

e = np.arange(108).reshape(2, 3, 3, 6)
Run Code Online (Sandbox Code Playgroud)

(请注意,我使用的实际数组可能包含任何随机整数;它们不包含像本例那样的连续整数。)

我要做的是对每个行/列组合进行上面的计算。以下适用于这个简单的示例,但对于具有更多维度的较大数组,这可能非常非常慢:

new_sum = np.empty((2,3))
for i in range(2):
   for j in range(3):
      temp_array = e[i,j,:,:]
      new_sum[i,j] = np.sum(temp_array[np.ix_([0,2],[0,2,4,5])])
Run Code Online (Sandbox Code Playgroud)

问题:以上是否可以以更快的方式完成,大概不需要使用循环?

作为脚注,上述结果如下:

>>> print(new_sum)
[[ 70. 214. 358.]
 [502. 646. 790.]]
Run Code Online (Sandbox Code Playgroud)

当然,左上角的 70 和我们之前得到的结果是一样的。

小智 2

您可以创建一个布尔矩阵(掩码),其中包含True您想要保留的值和False您不想要的值。

>>> mask = np.zeros((3,6), dtype='bool')
>>> mask[np.ix_([0,2],[0,2,4,5])] = True
>>> mask
array([[ True, False,  True, False,  True,  True],
       [False, False, False, False, False, False],
       [ True, False,  True, False,  True,  True]])
Run Code Online (Sandbox Code Playgroud)

然后,您可以利用 numpy 数组广播规则将掩码应用于数组并对最后一个维度求和:

>>> new_sum = np.sum(e * mask.reshape(1,1,3,6), axis=(2,3))
>>> new_sum
array([[ 70, 214, 358],
       [502, 646, 790]])
Run Code Online (Sandbox Code Playgroud)

下面是一个小代码,用于比较两个版本在更大矩阵上的性能:

import numpy as np
import time

N, P = 200, 100
e = np.arange(18*N*P).reshape(N, P, 3, 6)

t_start = time.time()
new_sum = np.empty((N,P))
for i in range(N):
   for j in range(P):
      temp_array = e[i,j,:,:]
      new_sum[i,j] = np.sum(temp_array[np.ix_([0,2],[0,2,4,5])])
print(f'Timer 1: {time.time()-t_start}s')

t_start = time.time()
mask = np.zeros((3,6), dtype='bool')
mask[np.ix_([0,2],[0,2,4,5])] = True
new_sum_2 = np.sum(e * mask.reshape(1,1,3,6), axis=(2,3))
print(f'Timer 2: {time.time()-t_start}s')

print('Results are equal!' if np.allclose(new_sum, new_sum_2) else 'Results differ!')
Run Code Online (Sandbox Code Playgroud)

输出:

% python3 script.py
Timer 1: 0.4343228340148926s
Timer 2: 0.002004384994506836s
Results are equal!
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,您在计算时间方面得到了显着的改进。