torch.mm、torch.matmul 和 torch.mul 之间有什么区别?

AIB*_*all 18 python-3.x pytorch

torch.mm阅读完 pytorch 文档后,我仍然需要帮助来理解,torch.matmul和之间的区别torch.mul。由于我不完全理解它们,所以我无法简明地解释这一点。

B = torch.tensor([[ 1.1207],
        [-0.3137],
        [ 0.0700],
        [ 0.8378]])

C = torch.tensor([[ 0.5146,  0.1216, -0.5244,  2.2382]])

print(torch.mul(B,C))

print(torch.matmul(B,C))

print(torch.mm(B,C))
Run Code Online (Sandbox Code Playgroud)

所有三个都会产生以下输出(即它们执行矩阵乘法):

tensor([[ 0.5767,  0.1363, -0.5877,  2.5084],
        [-0.1614, -0.0381,  0.1645, -0.7021],
        [ 0.0360,  0.0085, -0.0367,  0.1567],
        [ 0.4311,  0.1019, -0.4393,  1.8752]])
Run Code Online (Sandbox Code Playgroud)
A = torch.tensor([[1.8351,2.1536], [-0.8320,-1.4578]])
B = torch.tensor([[2.9355, 0.3450], [0.5708, 1.9957]])
print(torch.mul(A,B))
print(torch.matmul(A,B))
print(torch.mm(A,B))
Run Code Online (Sandbox Code Playgroud)

产生不同的输出。torch.mm 不再执行矩阵乘法(而是广播并执行逐元素乘法,而其他两个仍然执行矩阵乘法。

tensor([[ 5.3869,  0.7430],
        [-0.4749, -2.9093]])
tensor([[ 6.6162,  4.9310],
        [-3.2744, -3.1964]])
tensor([[ 6.6162,  4.9310],
        [-3.2744, -3.1964]])
Run Code Online (Sandbox Code Playgroud)

输入

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)

Run Code Online (Sandbox Code Playgroud)
tensor1 = 
tensor([[[-0.2267,  0.6311, -0.5689,  1.2712],
         [-0.0241, -0.5362,  0.5481, -0.4534],
         [-0.9773, -0.6842,  0.6927,  0.3363]],

        [[-2.6759,  0.7817,  2.6821,  0.7037],
         [ 0.1804,  0.3938, -1.2235,  0.8729],
         [-1.9873, -0.5030,  0.0945,  0.2688]],

        [[ 0.4244,  1.7350,  0.0558, -0.1861],
         [-0.9063, -0.4737, -0.4284, -0.3883],
         [ 0.4827, -0.2628,  1.0084,  0.2769]],

        [[ 0.2939,  0.4604,  0.8014, -1.8760],
         [ 1.8807,  0.1623,  0.2344, -0.6221],
         [ 1.3964,  3.1637,  0.7889,  0.1195]],

        [[-0.7202,  1.4250,  2.4302,  1.4811],
         [-0.2301,  0.6280,  0.5379,  0.5178],
         [-2.1073, -1.4399, -0.9451,  0.8534]],

        [[ 2.8178, -0.4451, -0.7871, -0.5198],
         [ 0.2825,  1.0692,  0.1559,  1.2945],
         [-0.5828, -1.6287, -2.0661, -0.4107]],

        [[ 0.5077, -0.6349, -0.0160, -0.4477],
         [-0.8070,  0.3746,  1.1852,  0.0351],
         [-0.6454,  1.5877,  0.8561,  1.1021]],

        [[ 0.1191,  1.0116,  0.5807,  1.2105],
         [-0.5403,  1.2404,  1.1532,  0.6537],
         [ 1.4757, -1.3648, -1.7158, -1.0289]],

        [[-0.1326,  0.3715,  0.2429, -0.0794],
         [ 0.3224, -0.3064,  0.1963,  0.7276],
         [ 0.9098,  1.5984, -1.4953,  0.0420]],

        [[ 0.1511,  0.9691, -0.5204,  0.3858],
         [ 0.4566,  1.5482, -0.3401,  0.5960],
         [-0.9998,  0.7198,  0.9286,  0.4498]]])

tensor2 =
tensor([-1.6350,  1.0335, -0.9023,  0.0696])
Run Code Online (Sandbox Code Playgroud)
print(torch.mul(tensor1,tensor2))
print(torch.matmul(tensor1,tensor2))
print(torch.mm(tensor1,tensor2))
Run Code Online (Sandbox Code Playgroud)

输出都不同。我认为torch.mul将矩阵的每4个元素广播并乘以向量,tensor2,即[-0.2267, 0.6311, -0.5689, 1.2712] x tensor 2 逐元素、[-0.0241, -0.5362, 0.5481, -0.4534] x tensor 2逐元素等等。我不明白 torch.matmul在做什么。我认为这与文档的第五个要点有关(如果两个参数......),但我无法理解这一点。https://pytorch.org/docs/stable/ generated/torch.matmul.html

torch.mm我认为无法产生输出的原因是它无法广播(如果我错了,请纠正我)。

tensor([[[ 3.7071e-01,  6.5221e-01,  5.1335e-01,  8.8437e-02],
         [ 3.9400e-02, -5.5417e-01, -4.9460e-01, -3.1539e-02],
         [ 1.5979e+00, -7.0715e-01, -6.2499e-01,  2.3398e-02]],

        [[ 4.3752e+00,  8.0790e-01, -2.4201e+00,  4.8957e-02],
         [-2.9503e-01,  4.0699e-01,  1.1040e+00,  6.0723e-02],
         [ 3.2494e+00, -5.1981e-01, -8.5253e-02,  1.8701e-02]],

        [[-6.9397e-01,  1.7931e+00, -5.0379e-02, -1.2945e-02],
         [ 1.4818e+00, -4.8954e-01,  3.8657e-01, -2.7010e-02],
         [-7.8920e-01, -2.7163e-01, -9.0992e-01,  1.9265e-02]],

        [[-4.8055e-01,  4.7582e-01, -7.2309e-01, -1.3051e-01],
         [-3.0750e+00,  1.6770e-01, -2.1146e-01, -4.3281e-02],
         [-2.2832e+00,  3.2697e+00, -7.1183e-01,  8.3139e-03]],

        [[ 1.1775e+00,  1.4727e+00, -2.1928e+00,  1.0304e-01],
         [ 3.7617e-01,  6.4900e-01, -4.8534e-01,  3.6025e-02],
         [ 3.4455e+00, -1.4882e+00,  8.5277e-01,  5.9369e-02]],

        [[-4.6072e+00, -4.6005e-01,  7.1024e-01, -3.6160e-02],
         [-4.6191e-01,  1.1051e+00, -1.4067e-01,  9.0053e-02],
         [ 9.5283e-01, -1.6833e+00,  1.8643e+00, -2.8571e-02]],

        [[-8.3005e-01, -6.5622e-01,  1.4461e-02, -3.1148e-02],
         [ 1.3195e+00,  3.8716e-01, -1.0694e+00,  2.4421e-03],
         [ 1.0553e+00,  1.6409e+00, -7.7250e-01,  7.6669e-02]],

        [[-1.9477e-01,  1.0455e+00, -5.2398e-01,  8.4209e-02],
         [ 8.8343e-01,  1.2820e+00, -1.0405e+00,  4.5478e-02],
         [-2.4128e+00, -1.4106e+00,  1.5482e+00, -7.1578e-02]],

        [[ 2.1675e-01,  3.8391e-01, -2.1914e-01, -5.5219e-03],
         [-5.2707e-01, -3.1668e-01, -1.7711e-01,  5.0619e-02],
         [-1.4876e+00,  1.6520e+00,  1.3493e+00,  2.9198e-03]],

        [[-2.4706e-01,  1.0015e+00,  4.6955e-01,  2.6842e-02],
         [-7.4663e-01,  1.6001e+00,  3.0685e-01,  4.1462e-02],
         [ 1.6347e+00,  7.4395e-01, -8.3792e-01,  3.1291e-02]]])
tensor([[ 1.6247, -1.0409,  0.2891],
        [ 2.8120,  1.2767,  2.6630],
        [ 1.0358,  1.3518, -1.9515],
        [-0.8583, -3.1620,  0.2830],
        [ 0.5605,  0.5759,  2.8694],
        [-4.3932,  0.5925,  1.1053],
        [-1.5030,  0.6397,  2.0004],
        [ 0.4109,  1.1704, -2.3467],
        [ 0.3760, -0.9702,  1.5165],
        [ 1.2509,  1.2018,  1.5720]])
Run Code Online (Sandbox Code Playgroud)

小智 30

简而言之:

\n
    \n
  • torch.mm- 执行矩阵乘法而不广播- (2D 张量) by (2D 张量)
  • \n
  • torch.mul-通过广播执行元素乘法- (张量)乘以(张量或数字)
  • \n
  • torch.matmul-具有广播的矩阵乘积- (张量)乘以(张量),具有不同的行为,具体取决于张量形状(点积、矩阵乘积、批量矩阵乘积)。
  • \n
\n

一些细节:

\n
    \n
  1. torch.mm- 执行矩阵乘法而不广播
  2. \n
\n

它需要两个 2D 张量,因此n\xc3\x97m * m\xc3\x97p = n\xc3\x97p

\n

从文档https://pytorch.org/docs/stable/ generated/torch.mm.html :

\n
This function does not broadcast. For broadcasting matrix products, see torch.matmul().\n
Run Code Online (Sandbox Code Playgroud)\n
    \n
  1. torch.mul-通过广播执行元素乘法- (张量)乘以(张量或数字)
  2. \n
\n

文档: https: //pytorch.org/docs/stable/ generated/torch.mul.html

\n

torch.mul不执行矩阵乘法。它广播两个张量并执行元素乘法。因此,当您将其与张量 1x4 * 4x1 一起使用时,其工作原理类似于:

\n
This function does not broadcast. For broadcasting matrix products, see torch.matmul().\n
Run Code Online (Sandbox Code Playgroud)\n
tensor([[1., 1., 1.],\n        [2., 2., 2.],\n        [3., 3., 3.]])\ntensor([[  1.,  10., 100.],\n        [  1.,  10., 100.],\n        [  1.,  10., 100.]])\ntensor([[  1.,  10., 100.],\n        [  2.,  20., 200.],\n        [  3.,  30., 300.]])\n
Run Code Online (Sandbox Code Playgroud)\n
    \n
  1. torch.matmul
  2. \n
\n

最好查看官方文档https://pytorch.org/docs/stable/ generated/torch.matmul.html 因为它根据输入张量使用不同的模式。它可以通过广播执行点积、矩阵-矩阵积或批量矩阵积。

\n

至于您关于以下产品的问题:

\n
tensor1 = torch.randn(10, 3, 4)\ntensor2 = torch.randn(4)\n
Run Code Online (Sandbox Code Playgroud)\n

它是产品的批量版本。请检查这个简单的例子来理解:

\n
import torch\n\n# 3x1x3\na = torch.FloatTensor([[[1, 2, 3]], [[3, 4, 5]], [[6, 7, 8]]])\n# 3\nb = torch.FloatTensor([1, 10, 100])\nr1 = torch.matmul(a, b)\n\nr2 = torch.stack((\n    torch.matmul(a[0], b),\n    torch.matmul(a[1], b),\n    torch.matmul(a[2], b),\n))\nassert torch.allclose(r1, r2)\n\n
Run Code Online (Sandbox Code Playgroud)\n

因此它可以看作是跨批次维度堆叠在一起的多个操作。

\n

阅读有关广播的内容也可能有用:

\n

https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

\n