blu*_*nox 9 python matrix linear-algebra neural-network pytorch
当在PyTorch中使用双线性层时,我无法绕过计算的方式.
这是一个小例子,我试图弄清楚它是如何工作的:
在:
import torch.nn as nn
B = nn.Bilinear(2, 2, 1)
print(B.weight)
Run Code Online (Sandbox Code Playgroud)
日期:
Parameter containing:
tensor([[[-0.4394, -0.4920],
[ 0.6137, 0.4174]]], requires_grad=True)
Run Code Online (Sandbox Code Playgroud)
我正在通过一个零向量和一个向量.
在:
print(B(torch.ones(2), torch.zeros(2)))
print(B(torch.zeros(2), torch.ones(2)))
Run Code Online (Sandbox Code Playgroud)
日期:
tensor([0.2175], grad_fn=<ThAddBackward>)
tensor([0.2175], grad_fn=<ThAddBackward>)
Run Code Online (Sandbox Code Playgroud)
我尝试以各种方式加权,但我没有得到相同的结果.
提前致谢!
完成的操作nn.Bilinear
是B(x1, x2) = x1*A*x2 + b
(cf doc):
A
存储在 nn.Bilinear.weight
b
存储在 nn.Bilinear.bias
如果考虑(可选)偏差,则应获得预期结果.
import torch
import torch.nn as nn
def manual_bilinear(x1, x2, A, b):
return torch.mm(x1, torch.mm(A, x2)) + b
x_ones = torch.ones(2)
x_zeros = torch.zeros(2)
# ---------------------------
# With Bias:
B = nn.Bilinear(2, 2, 1)
A = B.weight
print(B.bias)
# > tensor([-0.6748], requires_grad=True)
b = B.bias
print(B(x_ones, x_zeros))
# > tensor([-0.6748], grad_fn=<ThAddBackward>)
print(manual_bilinear(x_ones.view(1, 2), x_zeros.view(2, 1), A.squeeze(), b))
# > tensor([[-0.6748]], grad_fn=<ThAddBackward>)
print(B(x_ones, x_ones))
# > tensor([-1.7684], grad_fn=<ThAddBackward>)
print(manual_bilinear(x_ones.view(1, 2), x_ones.view(2, 1), A.squeeze(), b))
# > tensor([[-1.7684]], grad_fn=<ThAddBackward>)
# ---------------------------
# Without Bias:
B = nn.Bilinear(2, 2, 1, bias=False)
A = B.weight
print(B.bias)
# None
b = torch.zeros(1)
print(B(x_ones, x_zeros))
# > tensor([0.], grad_fn=<ThAddBackward>)
print(manual_bilinear(x_ones.view(1, 2), x_zeros.view(2, 1), A.squeeze(), b))
# > tensor([0.], grad_fn=<ThAddBackward>)
print(B(x_ones, x_ones))
# > tensor([-0.7897], grad_fn=<ThAddBackward>)
print(manual_bilinear(x_ones.view(1, 2), x_ones.view(2, 1), A.squeeze(), b))
# > tensor([[-0.7897]], grad_fn=<ThAddBackward>)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
1991 次 |
最近记录: |