我在下面编写了这段代码,以尝试了解这些钩子发生了什么。
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(10,5)
self.fc2 = nn.Linear(5,1)
self.fc1.register_forward_hook(self._forward_hook)
self.fc1.register_backward_hook(self._backward_hook)
def forward(self, inp):
return self.fc2(self.fc1(inp))
def _forward_hook(self, module, input, output):
print(type(input))
print(len(input))
print(type(output))
print(input[0].shape)
print(output.shape)
print()
def _backward_hook(self, module, grad_input, grad_output):
print(type(grad_input))
print(len(grad_input))
print(type(grad_output))
print(len(grad_output))
print(grad_input[0].shape)
print(grad_input[1].shape)
print(grad_output[0].shape)
print()
model = Model()
out = model(torch.tensor(np.arange(10).reshape(1,1,10), dtype=torch.float32))
out.backward()
Run Code Online (Sandbox Code Playgroud)
产生输出
<class 'tuple'>
1
<class 'torch.Tensor'>
torch.Size([1, 1, 10])
torch.Size([1, 1, 5])
<class 'tuple'>
2
<class 'tuple'>
1
torch.Size([1, 1, 5])
torch.Size([5])
torch.Size([1, 1, 5])
Run Code Online (Sandbox Code Playgroud)
您还可以按照此处的CNN 示例进行操作。事实上,需要理解我的问题的其余部分。
我有几个问题:
我通常认为grad_input
(后钩)应该与output
(前钩)形状相同,因为当我们向后走时,方向是相反的。但 CNN 的例子似乎表明情况并非如此。我还是有点困惑。是哪条路呢?
为什么我的图层上的grad_input[0]
和 是grad_output[0]
相同的形状Linear
?不管我的问题1的答案如何,至少其中一个应该是torch.Size([1, 1, 10])
正确的?
元组的第二个元素是什么grad_input
?在 CNN 案例中,我复制粘贴了示例并执行了print(grad_input[1].size())
输出torch.Size([20, 10, 5, 5])
。所以我认为这是权重的梯度。我也跑了print(grad_input[2].size())
,得到了torch.Size([20])
。所以很明显我正在研究偏差的梯度。但在我的Linear
例子中grad_input
,长度是 2,所以我最多只能访问grad_input[1]
,这似乎给了我偏差的梯度。那么权重的梯度在哪里呢?
总之,在Conv2d
和 “线性”模块的情况下,向后钩子的行为之间存在两个明显的矛盾。这让我完全不知道这个钩子会带来什么。
感谢您的帮助!
Piy*_*ngh 11
我通常认为 grad_input (后向钩子)应该与输出具有相同的形状
grad_input
backward
包含梯度(已调用的任何张量;通常它是进行机器学习时的损失张量,对您来说它只是该层的输出Model
) 。input
所以它的形状与 相同input
。类似地,grad_output
其形状与output
层的形状相同。您引用的 CNN 示例也是如此。
为什么 grad_input[0] 和 grad_output[0] 在我的线性层上具有相同的形状?无论我的问题 1 的答案如何,至少其中之一应该是 torch.Size([1, 1, 10]) 对吗?
理想情况下,grad_input
应该包含层输入的梯度以及层的权重和偏差。如果您对 CNN 示例使用以下后向钩子,您就会看到这种行为:
def _backward_hook(module, grad_input, grad_output):
for i, inp in enumerate(grad_input):
print("Input #", i, inp.shape)
Run Code Online (Sandbox Code Playgroud)
然而,该层不会发生这种情况Linear
。这是因为一个错误。置顶评论如下:
模块钩子实际上是在模块创建的最后一个函数上注册的
所以后端真正可能发生的事情(我的猜测)是它正在计算Y=((W^TX)+b)
。可以看到最后一个操作是添加偏置。因此,对于该操作,有一个形状为 (1,1,5) 的输入,并且偏置项具有形状 (5)。这两个(实际上是梯度)构成了你的 tuple grad_input
。加法的结果(实际上是梯度)存储在grad_output
形状 (1,1,5)中
元组 grad_input 的第二个元素是什么
正如上面所回答的,它只是梯度,无论“层参数”梯度是根据什么计算的;通常是最后一次操作的权重/偏差(无论适用的)。
归档时间: |
|
查看次数: |
12053 次 |
最近记录: |