ir0*_*098 2 python linear-algebra numpy-einsum einsum
我需要在以下代码中将 einsum 操作替换为标准 numpy 操作:
import numpy as np
a = np.random.rand(128, 16, 8, 32)
b = np.random.rand(256, 8, 32)
output = np.einsum('aijb,rjb->ira', a, b)
Run Code Online (Sandbox Code Playgroud)
在此先感谢您的帮助。
一种选择是对齐到相似的形状并广播乘法,然后sum重新排序轴:
output2 = (b[None, None]*a[:,:,None]).sum(axis=(-1, -2)).transpose((1, 2, 0))\n\n# assert np.allclose(output, output2)\nRun Code Online (Sandbox Code Playgroud)\n但这的效率要低得多,因为它会产生一个大的中间体(形状(128, 16, 256, 8, 32)):
# np.einsum('aijb,rjb->ira', a, b)\n68.9 ms \xc2\xb1 23.1 ms per loop (mean \xc2\xb1 std. dev. of 7 runs, 10 loops each)\n\n# (b[None, None]*a[:,:,None]).sum(axis=(-1, -2)).transpose((1, 2, 0))\n4.66 s \xc2\xb1 1.65 s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1 loop each)\nRun Code Online (Sandbox Code Playgroud)\n形状:
\n# b[None, None].shape\n#a i r j b\n(1, 1, 256, 8, 32)\n\n# a[:,:,None].shape\n# a i r j b\n(128, 16, 1, 8, 32)\nRun Code Online (Sandbox Code Playgroud)\n
| 归档时间: |
|
| 查看次数: |
117 次 |
| 最近记录: |