使用 pytorch 进行张量乘法的“@”

jav*_*dba 0 python pytorch

本文https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138关于智能权重初始化使用语法

x@w
Run Code Online (Sandbox Code Playgroud)

表示张量(/矩阵)乘法。我以前没有看到过这一点,而是认为我们需要将其“拼写出来”为:

 torch.mm(x, w.t())
Run Code Online (Sandbox Code Playgroud)

使用以前的(更好的)语法需要什么?那篇文章没有显示他们正在使用的完整导入集。

pro*_*sti 5

只有Python 3.5及以上版本可以使用这个“@”语法。以下是等效的:

a = torch.rand(2,2)
b = torch.rand(2,2)

c = a.mm(b)
print(c)

c = torch.mm(a, b)
print(c)

c = torch.matmul(a, b)
print(c)

c = a @ b # python > 3.5+
print(c)
Run Code Online (Sandbox Code Playgroud)

输出:

tensor([[0.2675, 0.8140],
        [0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
        [0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
        [0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
        [0.0415, 0.1644]])
Run Code Online (Sandbox Code Playgroud)

我喜欢使用mm矩阵到矩阵乘法和mv矩阵到向量乘法的语法。

为了获得转置矩阵,我喜欢使用简单的a.T语法。

还要补充一件事:

a = torch.rand(2,2,2)
b = torch.rand(2,2,2)

c = torch.matmul(a, b)
print(c)

c = a @ b # python > 3.5+
print(c)
Run Code Online (Sandbox Code Playgroud)

输出:

tensor([[[0.2951, 0.3021],
         [0.8663, 1.0430]],

        [[0.2674, 1.3792],
         [0.0895, 0.9703]]])
tensor([[[0.2951, 0.3021],
         [0.8663, 1.0430]],

        [[0.2674, 1.3792],
         [0.0895, 0.9703]]])
Run Code Online (Sandbox Code Playgroud)

mm不能用于rank>2(3 级或以上的张量)。因此,如果您使用更大的排名,请仅使用这些:matmul@符号。