矢量化大NumPy乘法

Jul*_*ien 5 python arrays numpy vectorization multiplication

我有兴趣计算一个大的NumPy数组.我有一个A包含大量数字的大型数组.我想计算这些数字的不同组合的总和.数据结构如下:

A = np.random.uniform(0,1, (3743, 1388, 3))
Combinations = np.random.randint(0,3, (306,3))
Final_Product = np.array([  np.sum( A*cb, axis=2)  for cb in Combinations])
Run Code Online (Sandbox Code Playgroud)

我的问题是,如果有更优雅和记忆效率更高的计算方法吗?np.dot()当涉及三维阵列时,我觉得这很令人沮丧.

如果它有帮助,理想的形状Final_Product应该是(3743,306,1388).目前Final_Product形状(306,3743,1388),所以我可以重塑到那里.

Div*_*kar 5

np.dot()除非您涉及可能包括的额外步骤,否则不会给您提供所需的输出reshaping.这是一种vectorized方法np.einsum,一次性完成,没有任何额外的内存开销 -

Final_Product = np.einsum('ijk,lk->lij',A,Combinations)
Run Code Online (Sandbox Code Playgroud)

为了完整性,这里有np.dotreshaping之前讨论的 -

M,N,R = A.shape
Final_Product = A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
Run Code Online (Sandbox Code Playgroud)

运行时测试并验证输出 -

In [138]: # Inputs ( smaller version of those listed in question )
     ...: A = np.random.uniform(0,1, (374, 138, 3))
     ...: Combinations = np.random.randint(0,3, (30,3))
     ...: 

In [139]: %timeit np.array([  np.sum( A*cb, axis=2)  for cb in Combinations])
1 loops, best of 3: 324 ms per loop

In [140]: %timeit np.einsum('ijk,lk->lij',A,Combinations)
10 loops, best of 3: 32 ms per loop

In [141]: M,N,R = A.shape

In [142]: %timeit A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
100 loops, best of 3: 15.6 ms per loop

In [143]: Final_Product =np.array([np.sum( A*cb, axis=2)  for cb in Combinations])
     ...: Final_Product2 = np.einsum('ijk,lk->lij',A,Combinations)
     ...: M,N,R = A.shape
     ...: Final_Product3 = A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
     ...: 

In [144]: print np.allclose(Final_Product,Final_Product2)
True

In [145]: print np.allclose(Final_Product,Final_Product3)
True
Run Code Online (Sandbox Code Playgroud)


Ale*_*ley 5

而不是dot你可以使用tensordot.您当前的方法相当于:

np.tensordot(A, Combinations, [2, 1]).transpose(2, 0, 1)
Run Code Online (Sandbox Code Playgroud)

请注意transpose最后将轴按正确顺序排列.

比如dot,该tensordot函数可以调用快速BLAS/LAPACK库(如果已安装它们),因此应该对大型数组执行良好.