训练 PyTorch 脚本直至收敛的标准方法是什么?

Cha*_*ker 6 python machine-learning deep-learning pytorch pytorch-lightning

检测模型是否收敛的标准方法是什么?我打算记录 5 次损失,每次损失有 95 个置信区间,如果他们都同意,那么我\xe2\x80\x99d 停止脚本。我假设收敛之前的训练必须已经在 PyTorch 或 PyTorch Lightning 中的某个地方实现。我不需要一个完美的解决方案,只需自动执行此操作的标准方法 - 即收敛时停止。

\n

我的解决方案很容易实现。一旦创建了一个标准并将减少更改为none。然后它将输出一个大小为 的张量[B]。每次记录时都会记录下来,并且它是 95 置信区间(如果您愿意,也可以是 std,但这精度要低得多)。然后,每次添加新损失及其置信区间时,请确保其大小保持为 5(或 10),并且 5 个损失彼此之间的 CI 范围在 95 以内。那么如果这是真的就停下来。

\n

您可以使用以下方法计算 CI:

\n
def torch_compute_confidence_interval(data: Tensor,\n                                           confidence: float = 0.95\n                                           ) -> Tensor:\n    """\n    Computes the confidence interval for a given survey of a data set.\n    """\n    n = len(data)\n    mean: Tensor = data.mean()\n    # se: Tensor = scipy.stats.sem(data)  # compute standard error\n    # se, mean: Tensor = torch.std_mean(data, unbiased=True)  # compute standard error\n    se: Tensor = data.std(unbiased=True) / (n**0.5)\n    t_p: float = float(scipy.stats.t.ppf((1 + confidence) / 2., n - 1))\n    ci = t_p * se\n    return mean, ci\n
Run Code Online (Sandbox Code Playgroud)\n

您可以按如下方式创建标准:

\n
loss: nn.Module = nn.CrossEntropyLoss(reduction=\'none\')\n
Run Code Online (Sandbox Code Playgroud)\n

所以现在火车损失已经很大了[B]

\n
\n

请注意,我知道如何用固定数量的纪元进行训练,所以我并不是真的在寻找它 - 只是当模型看起来收敛时何时停止的停止标准,当一个人查看他们的学习曲线时会做什么但自动地。

\n
\n

参考:\n https://forums.pytorchlightning.ai/t/what-is-the-standard-way-to-halt-a-script-when-it-has-converged/1415

\n

Mik*_*e B 1

在训练器中设置 EarlyStopping ( https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.EarlyStopping.html#pytorch_lightning.callbacks.EarlyStopping ) 回调

checkpoint_callbacks = [
    EarlyStopping(
        monitor="val_f1_score",
        min_delta=0.01,
        patience=10,  # NOTE no. val epochs, not train epochs
        verbose=False,
        mode="min",
    ),
]

trainer = pl.Trainer(callbacks=callbacks)
Run Code Online (Sandbox Code Playgroud)

这将监视val_f1_score训练期间的变化(请注意,您必须self.log("val_f1_score", val_f1) 在您的 中记录该值pl.LightningModule)。如果符合改进条件的最小数量变化 ( min_delta) 超过指定的 epoch 数,它将停止训练patience