小编mbs*_*diz的帖子

PyTorch中计算中间节点的梯度

我正在尝试了解 autograd 在 PyTorch 中的工作原理。在下面的简单程序中,我不明白为什么losswrtW1和的梯度W2None据我从文档中了解到,W1并且W2是不稳定的,因此无法计算梯度。是这样吗?我的意思是,我如何不能对中间节点的损失求导数?谁能解释一下我在这里缺少什么?

import torch
import torch.autograd as tau

W = tau.Variable(torch.FloatTensor([[0, 1]]), requires_grad=True)
a = tau.Variable(torch.FloatTensor([[2, 2]]), requires_grad=False)
b = tau.Variable(torch.FloatTensor([[3, 3]]), requires_grad=False)

W1 = W  + a * a
W2 = W1 - b * b * b
Z = W2 * W2

print 'W:', W
print 'W1:', W1
print 'W2:', W2
print 'Z:', Z

loss = torch.sum((Z - 3) * (Z …
Run Code Online (Sandbox Code Playgroud)

pytorch

1
推荐指数
1
解决办法
2712
查看次数

标签 统计

pytorch ×1