根据官方文档,使用train()或eval()会对某些模块产生影响。但是,现在我希望用我的自定义模块实现类似的功能,即它在train()打开时执行某些操作,而在eval()打开时执行一些不同的操作。我怎样才能做到这一点?
是的你可以。
正如您在源代码中看到的,eval()并且train()基本上正在更改一个名为的标志self.training(请注意,它是递归调用的):
def train(self: T, mode: bool = True) -> T:
self.training = mode
for module in self.children():
module.train(mode)
return self
def eval(self: T) -> T:
return self.train(False)
Run Code Online (Sandbox Code Playgroud)
此标志在每个nn.Module. 如果您的自定义模块继承了这个基类,那么实现您想要的非常简单:
import torch.nn as nn
class MyCustomModule(nn.Module):
def __init__(self):
super().__init__()
# [...]
def forward(self, x):
if self.training:
# train() -> training logic
else:
# eval() -> inference logic
Run Code Online (Sandbox Code Playgroud)