我正在尝试了解 autograd 在 PyTorch 中的工作原理。在下面的简单程序中,我不明白为什么losswrtW1和的梯度W2是None。据我从文档中了解到,是这样吗?我的意思是,我如何不能对中间节点的损失求导数?谁能解释一下我在这里缺少什么?W1并且W2是不稳定的,因此无法计算梯度。
import torch
import torch.autograd as tau
W = tau.Variable(torch.FloatTensor([[0, 1]]), requires_grad=True)
a = tau.Variable(torch.FloatTensor([[2, 2]]), requires_grad=False)
b = tau.Variable(torch.FloatTensor([[3, 3]]), requires_grad=False)
W1 = W + a * a
W2 = W1 - b * b * b
Z = W2 * W2
print 'W:', W
print 'W1:', W1
print 'W2:', W2
print 'Z:', Z
loss = torch.sum((Z - 3) * (Z - 3))
print 'loss:', loss
# free W gradient buffer in case you are running this cell more than 2 times
if W.grad is not None: W.grad.data.zero_()
loss.backward()
print 'W.grad:', W.grad
# all of them are None
print 'W1.grad:', W1.grad
print 'W2.grad:', W2.grad
print 'a.grad:', a.grad
print 'b.grad:', b.grad
print 'Z.grad:', Z.grad
Run Code Online (Sandbox Code Playgroud)
需要时,中间梯度会累积在 C++ 缓冲区中,但为了节省内存,默认情况下不会保留它们(在 python 对象中公开)。requires_grad=True仅保留设置的叶变量的梯度(因此W在您的示例中)
保留中间梯度的一种方法是注册一个钩子。这项工作的一个钩子是retain_grad()(参见 PR)在你的例子中,如果你写W2.retain_grad(),中间梯度W2将暴露在W2.grad
W1和W2不是易失性的(您可以通过访问它们的volatile属性(即:W1.volatile)来检查),并且不能是易失性的,因为它们不是叶变量(例如W、a和b)。相反,需要计算它们的梯度,参见它们的requires_grad属性。如果只有一个叶子变量 ,则不会构建volatile整个后向图(您可以通过创建 volatile 并查看损失梯度函数来检查)
a = tau.Variable(torch.FloatTensor([[2, 2]]), volatile=True)
# ...
assert loss.grad_fn is None
Run Code Online (Sandbox Code Playgroud)
总结
| 归档时间: |
|
| 查看次数: |
2712 次 |
| 最近记录: |