PyTorch:用于训练和测试/验证的不同前向方法

qwe*_*rtz 3 transformer-model neural-network python-3.x pytorch seq2seq

我目前正在尝试扩展基于 FairSeq/PyTorch的模型。在训练期间,我需要训练两个编码器:一个使用目标样本,另一个使用源样本。

所以当前的 forward 函数是这样的:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out
Run Code Online (Sandbox Code Playgroud)

基于这个想法,我想要这样的东西:

def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out
Run Code Online (Sandbox Code Playgroud)

有没有办法做到这一点?

编辑:这些是我的约束,因为我需要扩展FairseqEncoderDecoderModel

@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder) 
Run Code Online (Sandbox Code Playgroud)

编辑 2:传递给 Fairseq 中前向函数的参数可以通过实现您自己的 Criterion 来更改,例如参见CrossEntropyCriterion,其中sample['net_input']传递给__call__调用该forward方法的模型函数。

Szy*_*zke 6

首先,您应该始终使用和定义forward您在torch.nn.Module实例上调用的其他一些方法。

绝对不要eval()trsvchn所示那样重载,因为它是 PyTorch 定义的评估方法(请参阅此处)。此方法允许将模型内的层置于评估模式(例如,对Dropout或推理模式等层的特定更改BatchNorm)。

此外,您应该使用__call__魔术方法调用它。为什么?因为钩子和其他 PyTorch 特定的东西是以这种方式正确注册的。

其次,不要mode按照@Anant Mittal 的建议使用某些外部字符串变量。这就是trainPyTorch 中的变量的用途,通过它来区分模型是处于eval模式还是train模式是标准的。

话虽如此,你最好这样做:

import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if self.train:
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
Run Code Online (Sandbox Code Playgroud)

您可以(并且可以说应该)将上述方法拆分为两个单独的方法,但这还不错,因为该函数相当短且可读。如果可能的话,只要坚持 PyTorch 的处理方式,而不是一些临时解决方案。不,反向传播不会有问题,为什么会有一个?