在pytorch中构建参数组

Bla*_*ade 5 pytorch

torch.optim 文档中,指出可以使用不同的优化超参数对模型参数进行分组和优化。它说

\n
\n

例如,当想要指定每一层的学习率时,这非常有用:

\n
optim.SGD([\n                {\'params\': model.base.parameters()},\n                {\'params\': model.classifier.parameters(), \'lr\': 1e-3}\n            ], lr=1e-2, momentum=0.9)\n
Run Code Online (Sandbox Code Playgroud)\n

这意味着model.base\xe2\x80\x99s 个参数将使用默认学习率为1e-2model.classifier\xe2\x80\x99s 个参数将使用\n学习率为1e-3,动量0.9将用于所有参数。

\n
\n

我想知道如何定义这样具有parameters()属性的组。我想到的是以下形式的东西

\n
class MyModel(nn.Module):\n    def __init__(self):\n        super(MyModel, self).__init__()\n        self.base()\n        self.classifier()\n\n        self.relu = nn.ReLU()\n\n    def base(self):\n        self.fc1 = nn.Linear(1, 512)\n        self.fc2 = nn.Linear(512, 264)\n\n    def classifier(self):\n        self.fc3 = nn.Linear(264, 128)\n        self.fc4 = nn.Linear(128, 964)\n\n    def forward(self, y0):\n\n        y1 = self.relu(self.fc1(y0))\n        y2 = self.relu(self.fc2(y1))\n        y3 = self.relu(self.fc3(y2))\n\n        return self.fc4(y3)\n
Run Code Online (Sandbox Code Playgroud)\n

我应该如何修改上面的代码片段才能获得model.base.parameters()?定义 a并将所需层的s 和esnn.ParameterList显式添加到该列表的唯一方法是吗?最佳实践是什么?weightbias

\n

Iva*_*van 9

我将展示解决此问题的三种方法。但最终,这取决于个人喜好。


- 使用 分组参数nn.ModuleDict

我在这里注意到一个答案,用于nn.Sequential对图层进行分组,允许使用 的parameters属性来定位模型的不同部分nn.Sequential。事实上base,分类器可能不仅仅是连续层。我相信更通用的方法是让模块保持原样,而是初始化一个附加nn.ModuleDict模块,该模块将包含优化组在单独的nn.ModuleLists 中排序的所有参数:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(1, 512)
        self.fc2 = nn.Linear(512, 264)
        self.fc3 = nn.Linear(264, 128)
        self.fc4 = nn.Linear(128, 964)

        self.params = nn.ModuleDict({
            'base': nn.ModuleList([self.fc1, self.fc2]),
            'classifier': nn.ModuleList([self.fc3, self.fc4])})

    def forward(self, y0):
        y1 = self.relu(self.fc1(y0))
        y2 = self.relu(self.fc2(y1))
        y3 = self.relu(self.fc3(y2))
        return self.fc4(y3)
Run Code Online (Sandbox Code Playgroud)

然后你可以定义你的优化器:

optim.SGD([
    {'params': model.params.base.parameters()},
    {'params': model.params.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
Run Code Online (Sandbox Code Playgroud)

请注意MyModel'parameters生成器不会包含重复的参数。


- 创建用于访问参数组的接口。

一种不同的解决方案是在 中提供一个接口nn.Module将参数分成组:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(1, 512)
        self.fc2 = nn.Linear(512, 264)
        self.fc3 = nn.Linear(264, 128)
        self.fc4 = nn.Linear(128, 964)

    def forward(self, y0):
        y1 = self.relu(self.fc1(y0))
        y2 = self.relu(self.fc2(y1))
        y3 = self.relu(self.fc3(y2))
        return self.fc4(y3)

    def base_params(self):
        return chain(m.parameters() for m in [self.fc1, self.fc2])

    def classifier_params(self):
        return chain(m.parameters() for m in [self.fc3, self.fc4])
Run Code Online (Sandbox Code Playgroud)

导入itertools.chainchain.

然后定义你的优化器:

optim.SGD([
    {'params': model.base_params()},
    {'params': model.classifier_params(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
Run Code Online (Sandbox Code Playgroud)

- 使用nn.Module子代。

最后,您可以将模块部分定义为子模块(这里它归结为方法nn.Sequential,但您可以将其推广到任何子模块)。

class Base(nn.Sequential):
    def __init__(self):
        super().__init__(nn.Linear(1, 512),
                         nn.ReLU(),
                         nn.Linear(512, 264),
                         nn.ReLU())

class Classifier(nn.Sequential):
    def __init__(self):
        super().__init__(nn.Linear(264, 128),
                         nn.ReLU(),
                         nn.Linear(128, 964))

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.base = Base()
        self.classifier = Classifier()

    def forward(self, y0):
        features = self.base(y0)
        out = self.classifier(features)
        return out
Run Code Online (Sandbox Code Playgroud)

您再次可以使用与第一种方法相同的接口:

optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
Run Code Online (Sandbox Code Playgroud)

我认为这是最佳实践。但是,它迫使您将每个组件定义为单独的nn.Module,这在尝试更复杂的模型时可能会很麻烦。