Kuo*_*Kuo 5 python constructor forward pytorch
通常,在构造函数中,我们声明了所有要使用的层。在 forward 函数中,我们定义了模型将如何运行,从输入到输出。
我的问题是,如果直接在函数中调用那些预定义/内置的nn.Modules 会forward()怎样?这个Keras 函数 API风格对于Pytorch是否合法?如果不是,为什么?
更新:以这种方式构建的TestModel确实运行成功,没有报警。但是与传统方式相比,训练损失会缓慢下降。
import torch.nn as nn
from cnn import CNN
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.num_embeddings = 2020
self.embedding_dim = 51
def forward(self, input):
x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
# CNN is a customized class and nn.Module subclassed
# we will ignore the arguments for its instantiation
x = CNN(...)(x)
x = nn.ReLu()(x)
x = nn.Dropout(p=0.2)(x)
return output = x
Run Code Online (Sandbox Code Playgroud)
您需要考虑可训练参数的范围。
如果你在forward模型的函数中定义了一个 conv 层,那么这个“层”的范围及其可训练参数对于函数来说是本地的,并且在每次调用该forward方法后都会被丢弃。您无法更新和训练每次forward通过后不断被丢弃的权重。
但是,当 conv 层是您model的范围的成员时,它的范围超出了forward方法,并且只要model对象存在,可训练参数就会持续存在。通过这种方式,您可以更新和训练模型及其权重。
| 归档时间: |
|
| 查看次数: |
2942 次 |
| 最近记录: |