本文https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138关于智能权重初始化使用语法
x@w
Run Code Online (Sandbox Code Playgroud)
表示张量(/矩阵)乘法。我以前没有看到过这一点,而是认为我们需要将其“拼写出来”为:
torch.mm(x, w.t())
Run Code Online (Sandbox Code Playgroud)
使用以前的(更好的)语法需要什么?那篇文章没有显示他们正在使用的完整导入集。
只有Python 3.5及以上版本可以使用这个“@”语法。以下是等效的:
a = torch.rand(2,2)
b = torch.rand(2,2)
c = a.mm(b)
print(c)
c = torch.mm(a, b)
print(c)
c = torch.matmul(a, b)
print(c)
c = a @ b # python > 3.5+
print(c)
Run Code Online (Sandbox Code Playgroud)
输出:
tensor([[0.2675, 0.8140],
[0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
[0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
[0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
[0.0415, 0.1644]])
Run Code Online (Sandbox Code Playgroud)
我喜欢使用mm
矩阵到矩阵乘法和mv
矩阵到向量乘法的语法。
为了获得转置矩阵,我喜欢使用简单的a.T
语法。
还要补充一件事:
a = torch.rand(2,2,2)
b = torch.rand(2,2,2)
c = torch.matmul(a, b)
print(c)
c = a @ b # python > 3.5+
print(c)
Run Code Online (Sandbox Code Playgroud)
输出:
tensor([[[0.2951, 0.3021],
[0.8663, 1.0430]],
[[0.2674, 1.3792],
[0.0895, 0.9703]]])
tensor([[[0.2951, 0.3021],
[0.8663, 1.0430]],
[[0.2674, 1.3792],
[0.0895, 0.9703]]])
Run Code Online (Sandbox Code Playgroud)
mm
不能用于rank>2(3 级或以上的张量)。因此,如果您使用更大的排名,请仅使用这些:matmul
或@
符号。