我在官方 pytorch 教程中得到了这个 nn 结构:
输入 -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> 视图 -> 线性 -> relu -> 线性 -> relu -> 线性 -> MSELoss -> 损失
然后是如何使用变量中的内置 .grad_fn 向后跟踪 grad 的示例。
# Eg:
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
Run Code Online (Sandbox Code Playgroud)
所以我想我可以通过粘贴 next_function[0][0] 9 次来达到 Conv2d 的 grad 对象,因为给定的例子但我得到了索引之外的错误元组。那么如何正确索引这些反向传播对象呢?
小智 7
在运行教程中的以下内容后,在PyTorch CNN教程中:
output = net(input)
target = torch.randn(10) # a dummy target, for example
target = target.view(1, -1) # make it the same shape as output
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
Run Code Online (Sandbox Code Playgroud)
以下代码片段将打印完整图形:
def print_graph(g, level=0):
if g == None: return
print('*'*level*4, g)
for subg in g.next_functions:
print_graph(subg[0], level+1)
print_graph(loss.grad_fn, 0)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
1583 次 |
最近记录: |