model.train(False) 和 required_grad = False 之间的区别

Dr.*_*ick 5 machine-learning deep-learning pytorch

我使用 Pytorch 库,正在寻找一种方法来冻结模型中的权重和偏差。

我看到了这两个选项:

  1. model.train(False)

  2. for param in model.parameters(): param.requires_grad = False

有什么区别(如果有的话)以及我应该使用哪一个来冻结模型的当前状态?

tri*_*ror 4

他们非常不同。

与反向传播过程无关,当您训练或评估模型时,某些层具有不同的行为。在 pytorch 中,只有 2 个:BatchNorm(我认为在评估时会停止更新其运行平均值和偏差)和 Dropout(仅在训练模式下丢弃值)。因此model.train()model.eval()等效地model.train(false))只需设置一个布尔标志来告诉这两个层“冻结自己”。请注意,这两层没有任何受向后操作影响的参数(我认为在前向传递过程中批归一化缓冲区张量发生了变化)

另一方面,将所有参数设置为“requires_grad=false”只是告诉 pytorch 停止记录反向传播的梯度。这不会影响 BatchNorm 和 Dropout 层

如何冻结模型有点取决于您的用例,但我想说最简单的方法是使用torch.jit.trace。这将创建您的模型的冻结副本,完全处于您调用时的状态trace。您的模型不受影响。

通常,你会打电话

model.eval()
traced_model = torch.jit.trace(model, input)
Run Code Online (Sandbox Code Playgroud)