PyTorch Lightning raises an error when passing a model checkpoint

MAC*_*MAC 0 pytorch pytorch-lightning

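I am training a multi-label classification model based on a Hugging Face model, and I am using PyTorch Lightning to train it.

Here is the code:

Early stopping is triggered when the loss has stopped improving: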

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

We can start the training process:

checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)


trainer = pl.Trainer(
  logger=logger,
  callbacks=[early_stopping_callback],
  max_epochs=N_EPOCHS,
  checkpoint_callback=checkpoint_callback,
  gpus=1,
  progress_bar_refresh_rate=30
)
# checkpoint_callback=checkpoint_callback,

As soon as I run this, I get the following error:

~/.local/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py in _configure_checkpoint_callbacks(self, checkpoint_callback)
     75             if isinstance(checkpoint_callback, Callback):
     76                 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
---> 77             raise MisconfigurationException(error_msg)
     78         if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
     79             raise MisconfigurationException(

MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the `callbacks` argument in the Trainer constructor instead.

How can I fix this?

Iva*_*van 5

You can look up the description of the checkpoint_callback argument on the documentation page of pl.Trainer:

checkpoint_callback (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.

You shouldn't pass your custom ModelCheckpoint to this argument. I believe what you want to do is pass both the EarlyStopping and the ModelCheckpoint in the callbacks list:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min")

# Both callbacks go in the `callbacks` list; checkpoint_callback is left unset.
# `logger` and `N_EPOCHS` are the ones defined in the question's own code.
trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30)
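
To see why the Trainer rejects the instance, here is a minimal pure-Python sketch of the kind of check visible in the traceback, combined with the behaviour described in the quoted documentation. It is not Lightning's actual implementation: the classes `Callback`, `ModelCheckpoint`, and `MisconfigurationException` below are empty stand-ins for the real ones, and `configure_checkpointing` is a hypothetical helper named only for illustration.

```python
# Simplified stand-in for the Trainer's argument handling; NOT Lightning's code.


class Callback:
    """Empty stand-in for pytorch_lightning.Callback."""


class ModelCheckpoint(Callback):
    """Empty stand-in for the real ModelCheckpoint callback."""


class MisconfigurationException(Exception):
    """Stand-in for Lightning's exception of the same name."""


def configure_checkpointing(callbacks, checkpoint_callback=True):
    """Hypothetical helper: return the callback list the trainer would use."""
    # The flag is a plain on/off switch, so anything but a bool is rejected.
    if not isinstance(checkpoint_callback, bool):
        msg = (
            "Invalid type provided for checkpoint_callback: "
            f"Expected bool but received {type(checkpoint_callback)}."
        )
        if isinstance(checkpoint_callback, Callback):
            msg += " Pass callback instances to the `callbacks` argument instead."
        raise MisconfigurationException(msg)

    callbacks = list(callbacks)
    has_user_checkpoint = any(isinstance(cb, ModelCheckpoint) for cb in callbacks)

    # Disabling checkpointing while also supplying a checkpoint is contradictory.
    if has_user_checkpoint and checkpoint_callback is False:
        raise MisconfigurationException(
            "checkpoint_callback=False but a ModelCheckpoint is in `callbacks`."
        )

    # With the flag on and no user-defined checkpoint, a default one is added.
    if checkpoint_callback and not has_user_checkpoint:
        callbacks.append(ModelCheckpoint())
    return callbacks


if __name__ == "__main__":
    ckpt = ModelCheckpoint()
    try:
        configure_checkpointing([], checkpoint_callback=ckpt)  # the question's mistake
    except MisconfigurationException as err:
        print("rejected:", err)
    # The fix: the instance travels through the callbacks list.
    print(configure_checkpointing([ckpt]) == [ckpt])
```

Under these assumptions, the flag only decides whether checkpointing happens at all, while the callbacks list decides which checkpoint object is used, which is why the custom instance belongs in the list.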