如何对pytorch中的变量应用指数移动平均衰减?

jef*_*jef 5 deep-learning pytorch

我正在阅读以下论文。它对变量使用 EMA 衰减。
用于机器理解的双向注意力流

在训练过程中,模型所有权重的移动平均值保持为 0.999 的指数衰减率。

他们使用TensorFlow,我找到了EMA的相关代码。
https://github.com/allenai/bi-att-flow/blob/master/basic/model.py#L229

在 PyTorch 中,如何将 EMA 应用于变量?

Bru*_*hou -7

移动平均线是梯度下降中动量的关键概念。

PyTorch 文档中你可以找到:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

将参数更改momentum为您想要的值。