4D numpy 数组上的矩阵乘法

4be*_*ars 5 python arrays numpy matrix

我需要对两个 4D 数组(m 和 n)执行矩阵乘法,m 和 n 的尺寸分别为 2x2x2x2 和 2x3x2x2,这应该会产生 2x3x2x2 数组。经过大量研究(主要在本网站上),似乎可以使用np.einsumnp.tensordot有效地完成此操作,但我无法复制从 Matlab 获得的答案(手动验证)。我了解这些方法(einsum和tensordot)在二维数组上执行矩阵乘法时如何工作(这里清楚地解释了清楚解释),但我无法获得 4D 数组的正确轴索引。显然我\xe2\x80\x99m 缺少一些东西!我的实际问题涉及两个 23x23x3x3 复数数组,但我的测试数组是:

\n\n
a = np.array([[1, 7], [4, 3]]) \nb = np.array([[2, 9], [4, 5]]) \nc = np.array([[3, 6], [1, 0]]) \nd = np.array([[2, 8], [1, 2]]) \ne = np.array([[0, 0], [1, 2]])\nf = np.array([[2, 8], [1, 0]])\n\nm = np.array([[a, b], [c, d]])              # (2,2,2,2)\nn = np.array([[e, f, a], [b, d, c]])        # (2,3,2,2)\n
Run Code Online (Sandbox Code Playgroud)\n\n

我意识到复数可能会带来更多问题,但现在,我只是想了解索引如何与 einsum 和 tensordot 一起使用。我\xe2\x80\x99m追逐的答案是这个2x3x2x2数组:

\n\n
+----+-----------+-----------+-----------+\n|    | 0         | 1         | 2         |\n+====+===========+===========+===========+\n|  0 | [[47 77]  | [[22 42]  | [[44 40]  |\n|    |  [31 67]] |  [27 74]] |  [33 61]] |\n+----+-----------+-----------+-----------+\n|  1 | [[42 70]  | [[24 56]  | [[41 51]  |\n|    |  [10 19]] |  [ 6 20]] |  [ 6 13]] |\n+----+-----------+-----------+-----------+\n
Run Code Online (Sandbox Code Playgroud)\n\n

我最接近的尝试是使用 np.tensordot:

\n\n
mn = np.tensordot(m,n, axes=([1,3],[0,2]))\n
Run Code Online (Sandbox Code Playgroud)\n\n

这给了我一个 2x2x3x2 数组,其中包含正确的数字,但顺序不正确:

\n\n
+----+-----------+-----------+\n|    | 0         | 1         |\n+====+===========+===========+\n|  0 | [[47 77]  | [[31 67]  |\n|    |  [22 42]  |  [24 74]  |\n|    |  [44 40]] |  [33 61]] |\n+----+-----------+-----------+\n|  1 | [[42 70]  | [[10 19]  |\n|    |  [24 56]  |  [ 6 20]  |\n|    |  [41 51]] |  [ 6 13]] |\n+----+-----------+-----------+\n
Run Code Online (Sandbox Code Playgroud)\n\n

我\xe2\x80\x99ve也尝试从这里实现一些解决方案,但没有任何运气。
\n任何有关我如何改进这一点的想法将不胜感激,谢谢

\n

Div*_*kar 3

您可以简单地交换结果上的轴tensordot,这样我们仍然可以利用BLAS基于总和的缩减tensordot-

np.tensordot(m,n, axes=((1,3),(0,2))).swapaxes(1,2)
Run Code Online (Sandbox Code Playgroud)

或者,我们可以交换调用中m和的位置并转置以重新排列所有轴 -ntensordot

np.tensordot(n,m, axes=((0,2),(1,3))).transpose(2,0,3,1)
Run Code Online (Sandbox Code Playgroud)

通过手动重塑和交换轴,我们也可以引入2D矩阵乘法np.dot,如下所示 -

m0,m1,m2,m3 = m.shape
n0,n1,n2,n3 = n.shape
m2D = m.swapaxes(1,2).reshape(-1,m1*m3)
n2D = n.swapaxes(1,2).reshape(n0*n2,-1)
out = m2D.dot(n2D).reshape(m0,m2,n1,n3).swapaxes(1,2)
Run Code Online (Sandbox Code Playgroud)

运行时测试 -

将输入数组缩放为10x形状:

In [85]: m = np.random.rand(20,20,20,20)

In [86]: n = np.random.rand(20,30,20,20)

# @Daniel F's soln with einsum
In [87]: %timeit np.einsum('ijkl,jmln->imkn', m, n)
10 loops, best of 3: 136 ms per loop

In [126]: %timeit np.tensordot(m,n, axes=((1,3),(0,2))).swapaxes(1,2)
100 loops, best of 3: 2.31 ms per loop

In [127]: %timeit np.tensordot(n,m, axes=((0,2),(1,3))).transpose(2,0,3,1)
100 loops, best of 3: 2.37 ms per loop

In [128]: %%timeit
     ...: m0,m1,m2,m3 = m.shape
     ...: n0,n1,n2,n3 = n.shape
     ...: m2D = m.swapaxes(1,2).reshape(-1,m1*m3)
     ...: n2D = n.swapaxes(1,2).reshape(n0*n2,-1)
     ...: out = m2D.dot(n2D).reshape(m0,m2,n1,n3).swapaxes(1,2)
100 loops, best of 3: 2.36 ms per loop
Run Code Online (Sandbox Code Playgroud)