用正常操作替换 einsum

ir0*_*098 2 python linear-algebra numpy-einsum einsum

我需要在以下代码中将 einsum 操作替换为标准 numpy 操作:

import numpy as np
a = np.random.rand(128, 16, 8, 32)
b = np.random.rand(256, 8, 32)
output = np.einsum('aijb,rjb->ira', a, b)
Run Code Online (Sandbox Code Playgroud)

在此先感谢您的帮助。

moz*_*way 5

一种选择是对齐到相似的形状并广播乘法,然后sum重新排序轴:

\n
output2 = (b[None, None]*a[:,:,None]).sum(axis=(-1, -2)).transpose((1, 2, 0))\n\n# assert np.allclose(output, output2)\n
Run Code Online (Sandbox Code Playgroud)\n

但这的效率要低得多,因为它会产生一个大的中间体(形状(128, 16, 256, 8, 32)):

\n
# np.einsum('aijb,rjb->ira', a, b)\n68.9 ms \xc2\xb1 23.1 ms per loop (mean \xc2\xb1 std. dev. of 7 runs, 10 loops each)\n\n# (b[None, None]*a[:,:,None]).sum(axis=(-1, -2)).transpose((1, 2, 0))\n4.66 s \xc2\xb1 1.65 s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1 loop each)\n
Run Code Online (Sandbox Code Playgroud)\n

形状:

\n
# b[None, None].shape\n#a  i    r  j   b\n(1, 1, 256, 8, 32)\n\n# a[:,:,None].shape\n#  a   i  r  j   b\n(128, 16, 1, 8, 32)\n
Run Code Online (Sandbox Code Playgroud)\n