Pytorch:如何获取图中的所有张量

Ten*_*rye 9 python deep-learning pytorch

我想访问图形的所有张量实例。例如,我可以检查张量是否分离或者我可以检查大小。它可以在tensorflow 中完成。

想要图形的可视化。

Iva*_*van 2

您可以在运行时访问整个计算图。为此,您可以使用钩子。这些函数插入到nn.Modules 上,用于推理和反向传播。

在推理时,您可以使用 挂钩回调函数register_forward_hook。类似地,对于反向传播,您可以使用register_full_backward_hook.
注意:从 PyTorch 1.8.0 register_backward_hook版开始,已弃用。

通过这两个函数,您基本上可以访问计算图上的任何张量。是否要打印所有张量、打印形状,甚至插入断点进行研究,完全取决于您。

这是一个可能的实现:

def forward_hook(module, input, output):
    # ...
Run Code Online (Sandbox Code Playgroud)

参数由 PyTorch 作为元组input传递,并将包含传递给挂钩模块的前向函数的所有参数。

def backward_hook(module, grad_input, grad_output):
    # ...
Run Code Online (Sandbox Code Playgroud)

对于后向钩子, 和grad_input都是grad_output,并且根据模型的层具有不同的形状。

然后您可以将这些回调挂钩到任何现有的nn.Module. 例如,您可以循环模型中的所有子模块:

for module in model.children():
    module.register_forward_hook(forward_hook)
    module.register_full_backward_hook(backward_hook)
Run Code Online (Sandbox Code Playgroud)

要获取模块的名称,您可以包装钩子以将名称括起来并在模型的上循环named_modules

def forward_hook(name):
    def hook(module, x, y):
        print(f'{name}: {[tuple(i.shape) for i in x]} -> {list(y.shape)}')
    return hook

for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))
Run Code Online (Sandbox Code Playgroud)

它可以在推理上打印以下内容:

fc1: [(1, 100)] -> (1, 10)
fc2: [(1, 10)] -> (1, 5)
fc3: [(1, 5)] -> (1, 1)
Run Code Online (Sandbox Code Playgroud)

至于模型的参数,您可以通过调用 轻松访问两个钩子中给定模块的参数module.parameters。这将返回一个生成器。