如何为 pytorch 图层指定名称?

Gul*_*zar 9 python machine-learning neural-network deep-learning pytorch

根据上一个问题,我想绘制权重、偏差、激活和梯度以获得与此类似的结果。

使用

for name, param in model.named_parameters():
    summary_writer.add_histogram(f'{name}.grad', param.grad, step_index)
Run Code Online (Sandbox Code Playgroud)

正如上一个问题中所建议的,给出了次优结果,因为层名称类似于'_decoder._decoder.4.weight',这很难遵循,特别是因为架构由于研究而发生变化。4这一次的运行在下一次不会是一样的,而且真的没有意义。

因此,我想为每一层赋予我自己的字符串名称。


我找到了这个Pytorch 论坛讨论,但没有就任何最佳实践达成一致。

为 Pytorch 层分配名称的推荐方法是什么?

即,以各种方式定义的层:

  1. 顺序:
self._seq = nn.Sequential(nn.Linear(1, 2), nn.Linear(3, 4),)
Run Code Online (Sandbox Code Playgroud)
  1. 动态的:
self._dynamic = nn.ModuleList()
    for _ in range(self._n_features): 
        self._last_layer.append(nn.Conv1d(in_channels=5, out_channels=6, kernel_size=3, stride=1, padding=1,),)
Run Code Online (Sandbox Code Playgroud)
  1. 直接的:
self._direct = nn.Linear(7, 8)
Run Code Online (Sandbox Code Playgroud)
  1. 其他我没想到的方式

我希望能够为每个层提供一个字符串名称,以上述每种方式定义。

Szy*_*zke 13

顺序

传递collections.OrderedDict的实例。下面的代码给出conv1.weights, conv1.bias, conv2.weight, conv2.bias(注意缺少torch.nn.ReLU(),请参阅此答案的末尾)。

import collections

import torch

model = torch.nn.Sequential(
    collections.OrderedDict(
        [
            ("conv1", torch.nn.Conv2d(1, 20, 5)),
            ("relu1", torch.nn.ReLU()),
            ("conv2", torch.nn.Conv2d(20, 64, 5)),
            ("relu2", torch.nn.ReLU()),
        ]
    )
)

for name, param in model.named_parameters():
    print(name)
Run Code Online (Sandbox Code Playgroud)

动态的

使用ModuleDict而不是ModuleList

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.whatever = torch.nn.ModuleDict(
            {f"my_name{i}": torch.nn.Conv2d(10, 10, 3) for i in range(5)}
        )
Run Code Online (Sandbox Code Playgroud)

将为我们动态提供每个创建的模块whatever.my_name{i}.weight(或bias)。

直接的

只要你想怎么命名就可以了,这就是它的命名方式

self.my_name_or_whatever = nn.Linear(7, 8)
Run Code Online (Sandbox Code Playgroud)

你没有想过

  • 如果你想绘制权重、偏差及其梯度,你可以沿着这条路线走
  • 您无法以这种方式绘制激活(或激活的输出)。使用PyTorch hooks代替(如果你想要每层梯度通过网络时也使用这个)

对于最后一个任务,您可以使用第三方库torchfunc(免责声明:我是作者)或直接编写您自己的钩子。