您如何在 pytorch 中正确使用 grad_fn 上的 next_functions[0][0] ?

Ink*_*ay_ 4 pytorch

我在官方 pytorch 教程中得到了这个 nn 结构:

输入 -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> 视图 -> 线性 -> relu -> 线性 -> relu -> 线性 -> MSELoss -> 损失

然后是如何使用变量中的内置 .grad_fn 向后跟踪 grad 的示例。

# Eg: 
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU
Run Code Online (Sandbox Code Playgroud)

所以我想我可以通过粘贴 next_function[0][0] 9 次来达到 Conv2d 的 grad 对象,因为给定的例子但我得到了索引之外的错误元组。那么如何正确索引这些反向传播对象呢?

小智 7

在运行教程中的以下内容后,在PyTorch CNN教程中:

output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)
Run Code Online (Sandbox Code Playgroud)

以下代码片段将打印完整图形:

def print_graph(g, level=0):
    if g == None: return
    print('*'*level*4, g)
    for subg in g.next_functions:
        print_graph(subg[0], level+1)

print_graph(loss.grad_fn, 0)
Run Code Online (Sandbox Code Playgroud)