anu*_*rag 0 python matrix-multiplication pytorch tensor
我遇到了一个用于torch.einsum计算张量乘法的代码。我能够理解低阶张量的工作原理,但是不能理解 4D 张量的工作原理,如下所示:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
Run Code Online (Sandbox Code Playgroud)
我需要以下方面的帮助:
torch.einsum在这种情况下实际上有好处吗?(如果您只想详细了解 einsum 中涉及的步骤,请跳至 tl;dr 部分)
\n我将尝试einsum逐步解释此示例的工作原理,但torch.einsum我不会使用 ,而是使用numpy.einsum(文档),它的作用完全相同,但总的来说,我对它更满意。尽管如此,同样的步骤也发生在 torch 上。
让我们在 NumPy 中重写上面的代码 -
\nimport numpy as np\n\na = np.random.random((3, 5, 2, 10))\nb = np.random.random((3, 4, 2, 10))\nc = np.einsum(\'nxhd,nyhd->nhxy\', a,b)\nc.shape\n\n#(3, 2, 5, 4)\nRun Code Online (Sandbox Code Playgroud)\nEinsum 由 3个步骤组成:multiply、sumtranspose
让我们看看我们的尺寸。我们有一个(3, 5, 2, 10)和一个(3, 4, 2, 10)我们需要(3, 2, 5, 4)基于\'nxhd,nyhd->nhxy\'
我们不必担心n,x,y,h,d轴的顺序,只需担心是否要保留它们或删除(减少)它们。把它们写成表格,看看我们如何安排我们的尺寸 -
## Multiply ##\n n x y h d\n --------------------\na -> 3 5 2 10\nb -> 3 4 2 10\nc1 -> 3 5 4 2 10\nRun Code Online (Sandbox Code Playgroud)\nx为了使和axis之间的广播乘法得到y,(x, y)我们必须在正确的位置添加一个新轴,然后相乘。
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)\nb1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)\n\nc1 = a1*b1\nc1.shape\n\n#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)\nRun Code Online (Sandbox Code Playgroud)\n接下来,我们要将最后一个轴减小 10。这将为我们提供尺寸(n,x,y,h)。
## Reduce ##\n n x y h d\n --------------------\nc1 -> 3 5 4 2 10\nc2 -> 3 5 4 2\nRun Code Online (Sandbox Code Playgroud)\n这很简单。让我们np.sum重新做一下axis=-1
c2 = np.sum(c1, axis=-1)\nc2.shape\n\n#(3,5,4,2) #<-- (n, x, y, h)\nRun Code Online (Sandbox Code Playgroud)\n最后一步是使用转置重新排列轴。我们可以用np.transpose这个。np.transpose(0,3,1,2)基本上将第 3 个轴放在第 0 个轴之后,并推动第 1 和第 2 个轴。所以,(n,x,y,h)变成(n,h,x,y)
c3 = c2.transpose(0,3,1,2)\nc3.shape\n\n#(3,2,5,4) #<-- (n, h, x, y)\nRun Code Online (Sandbox Code Playgroud)\n让我们做最后一次检查,看看 c3 是否与从np.einsum-生成的 c 相同
np.allclose(c,c3)\n\n#True\nRun Code Online (Sandbox Code Playgroud)\n因此,我们实施了\'nxhd , nyhd -> nhxy\'-
input -> nxhd, nyhd\nmultiply -> nxyhd #broadcasting\nsum -> nxyh #reduce\ntranspose -> nhxy\nRun Code Online (Sandbox Code Playgroud)\n与采取多个步骤相比,优点np.einsum是您可以选择进行计算并使用同一函数执行多个操作所需的“路径”。这可以通过参数来完成optimize,它将优化 einsum 表达式的收缩顺序。
这些操作的非详尽列表可以通过 计算einsum,如下所示以及示例:
numpy.trace.numpy.diag.numpy.sum.numpy.transpose.numpy.matmul numpy.dot.numpy.inner numpy.outer。numpy.multiply.numpy.tensordot.numpy.einsum_path.%%timeit\nnp.einsum(\'nxhd,nyhd->nhxy\', a,b)\n#8.03 \xc2\xb5s \xc2\xb1 495 ns per loop (mean \xc2\xb1 std. dev. of 7 runs, 100000 loops each)\nRun Code Online (Sandbox Code Playgroud)\n%%timeit\nnp.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)\n#13.7 \xc2\xb5s \xc2\xb1 1.42 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 100000 loops each)\nRun Code Online (Sandbox Code Playgroud)\n它表明该np.einsum操作比单独的步骤更快。
| 归档时间: |
|
| 查看次数: |
2224 次 |
| 最近记录: |