我可以让我的自定义 pytorch 模块在调用 train() 或 eval() 时表现不同吗?

ihd*_*hdv 6 python pytorch

根据官方文档,使用train()eval()会对某些模块产生影响。但是,现在我希望用我的自定义模块实现类似的功能,即它在train()打开时执行某些操作,而在eval()打开时执行一些不同的操作。我怎样才能做到这一点?

Ber*_*iel 6

是的你可以。

正如您在源代码中看到的,eval()并且train()基本上正在更改一个名为的标志self.training(请注意,它是递归调用的):

def train(self: T, mode: bool = True) -> T:
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

def eval(self: T) -> T:
    return self.train(False)
Run Code Online (Sandbox Code Playgroud)

此标志在每个nn.Module. 如果您的自定义模块继承了这个基类,那么实现您想要的非常简单:

import torch.nn as nn


class MyCustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        # [...]

    def forward(self, x):
        if self.training:
            # train() -> training logic
        else:
            # eval()  -> inference logic
Run Code Online (Sandbox Code Playgroud)