Pytorch:我们可以在 forward() 函数中直接使用 nn.Module 层吗?

Kuo*_*Kuo 5 python constructor forward pytorch

通常,在构造函数中,我们声明了所有要使用的层。在 forward 函数中,我们定义了模型将如何运行,从输入到输出。

我的问题是,如果直接在函数中调用那些预定义/内置的nn.Modules 会forward()怎样?这个Keras 函数 API风格对于Pytorch是否合法?如果不是,为什么?

更新:以这种方式构建的TestModel确实运行成功,没有报警。但是与传统方式相比,训练损失会缓慢下降。

import torch.nn as nn
from cnn import CNN

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_embeddings = 2020
        self.embedding_dim = 51

    def forward(self, input):
        x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
        # CNN is a customized class and nn.Module subclassed
        # we will ignore the arguments for its instantiation 
        x = CNN(...)(x)
        x = nn.ReLu()(x)
        x = nn.Dropout(p=0.2)(x)
        return output = x
Run Code Online (Sandbox Code Playgroud)

Sha*_*hai 5

您需要考虑可训练参数的范围

如果你在forward模型的函数中定义了一个 conv 层,那么这个“层”的范围及其可训练参数对于函数来说是本地的,并且在每次调用该forward方法后都会被丢弃。您无法更新和训练每次forward通过后不断被丢弃的权重。
但是,当 conv 层是您model的范围的成员时,它的范围超出了forward方法,并且只要model对象存在,可训练参数就会持续存在。通过这种方式,您可以更新和训练模型及其权重。