torch.einsum 如何执行这个 4D 张量乘法?

anu*_*rag 0 python matrix-multiplication pytorch tensor

我遇到了一个用于torch.einsum计算张量乘法的代码。我能够理解低阶张量的工作原理,但是不能理解 4D 张量的工作原理,如下所示:

import torch

a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))

c = torch.einsum('nxhd,nyhd->nhxy', [a,b])

print(c.size())

# output: torch.Size([3, 2, 5, 4])
Run Code Online (Sandbox Code Playgroud)

我需要以下方面的帮助:

  1. 这里执行的操作是什么(解释矩阵如何相乘/转置等)?
  2. torch.einsum在这种情况下实际上有好处吗?

Aks*_*gal 6

(如果您只想详细了解 einsum 中涉及的步骤,请跳至 tl;dr 部分)

\n

我将尝试einsum逐步解释此示例的工作原理,但torch.einsum我不会使用 ,而是使用numpy.einsum(文档),它的作用完全相同,但总的来说,我对它更满意。尽管如此,同样的步骤也发生在 torch 上。

\n

让我们在 NumPy 中重写上面的代码 -

\n
import numpy as np\n\na = np.random.random((3, 5, 2, 10))\nb = np.random.random((3, 4, 2, 10))\nc = np.einsum(\'nxhd,nyhd->nhxy\', a,b)\nc.shape\n\n#(3, 2, 5, 4)\n
Run Code Online (Sandbox Code Playgroud)\n
\n

一步一步 np.einsum

\n

Einsum 由 3个步骤组成:multiplysumtranspose

\n

让我们看看我们的尺寸。我们有一个(3, 5, 2, 10)和一个(3, 4, 2, 10)我们需要(3, 2, 5, 4)基于\'nxhd,nyhd->nhxy\'

\n

1. 乘法

\n

我们不必担心n,x,y,h,d轴的顺序,只需担心是否要保留它们或删除(减少)它们。把它们写成表格,看看我们如何安排我们的尺寸 -

\n
        ## Multiply ##\n       n   x   y   h   d\n      --------------------\na  ->  3   5       2   10\nb  ->  3       4   2   10\nc1 ->  3   5   4   2   10\n
Run Code Online (Sandbox Code Playgroud)\n

x为了使和axis之间的广播乘法得到y(x, y)我们必须在正确的位置添加一个新轴,然后相乘。

\n
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)\nb1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)\n\nc1 = a1*b1\nc1.shape\n\n#(3, 5, 4, 2, 10)  #<-- (n, x, y, h, d)\n
Run Code Online (Sandbox Code Playgroud)\n

2.求和/减少

\n

接下来,我们要将最后一个轴减小 10。这将为我们提供尺寸(n,x,y,h)

\n
          ## Reduce ##\n        n   x   y   h   d\n       --------------------\nc1  ->  3   5   4   2   10\nc2  ->  3   5   4   2\n
Run Code Online (Sandbox Code Playgroud)\n

这很简单。让我们np.sum重新做一下axis=-1

\n
c2 = np.sum(c1, axis=-1)\nc2.shape\n\n#(3,5,4,2)  #<-- (n, x, y, h)\n
Run Code Online (Sandbox Code Playgroud)\n

3. 转置

\n

最后一步是使用转置重新排列轴。我们可以用np.transpose这个。np.transpose(0,3,1,2)基本上将第 3 个轴放在第 0 个轴之后,并推动第 1 和第 2 个轴。所以,(n,x,y,h)变成(n,h,x,y)

\n
c3 = c2.transpose(0,3,1,2)\nc3.shape\n\n#(3,2,5,4)  #<-- (n, h, x, y)\n
Run Code Online (Sandbox Code Playgroud)\n

4. 最终检查

\n

让我们做最后一次检查,看看 c3 是否与从np.einsum-生成的 c 相同

\n
np.allclose(c,c3)\n\n#True\n
Run Code Online (Sandbox Code Playgroud)\n
\n

TL;博士。

\n

因此,我们实施了\'nxhd , nyhd -> nhxy\'-

\n
input     -> nxhd, nyhd\nmultiply  -> nxyhd      #broadcasting\nsum       -> nxyh       #reduce\ntranspose -> nhxy\n
Run Code Online (Sandbox Code Playgroud)\n
\n

优势

\n

与采取多个步骤相比,优点np.einsum是您可以选择进行计算并使用同一函数执行多个操作所需的“路径”。这可以通过参数来完成optimize,它将优化 einsum 表达式的收缩顺序。

\n

这些操作的非详尽列表可以通过 计算einsum,如下所示以及示例:

\n
    \n
  • 数组的踪迹,numpy.trace.
  • \n
  • 返回对角线,numpy.diag.
  • \n
  • 数组轴求和,numpy.sum.
  • \n
  • 转置和排列,numpy.transpose.
  • \n
  • 矩阵乘法和点积,numpy.matmul numpy.dot.
  • \n
  • 矢量内积和外积,numpy.inner numpy.outer
  • \n
  • 广播、逐元素和标量乘法,numpy.multiply.
  • \n
  • 张量收缩,numpy.tensordot.
  • \n
  • 链式数组操作,低效的计算顺序,numpy.einsum_path.
  • \n
\n
\n

基准测试

\n
%%timeit\nnp.einsum(\'nxhd,nyhd->nhxy\', a,b)\n#8.03 \xc2\xb5s \xc2\xb1 495 ns per loop (mean \xc2\xb1 std. dev. of 7 runs, 100000 loops each)\n
Run Code Online (Sandbox Code Playgroud)\n
%%timeit\nnp.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)\n#13.7 \xc2\xb5s \xc2\xb1 1.42 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 100000 loops each)\n
Run Code Online (Sandbox Code Playgroud)\n

它表明该np.einsum操作比单独的步骤更快。

\n