MAC*_*MAC 0 pytorch pytorch-lightning
我正在使用拥抱面部模型训练多标签分类问题。我正在使用 Pytorch Lightning 来训练模型。
这是代码:
当损失最后没有改善时,就会触发提前停止
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)
Run Code Online (Sandbox Code Playgroud)
我们可以开始训练过程:
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints",
filename="best-checkpoint",
save_top_k=1,
verbose=True,
monitor="val_loss",
mode="min"
)
trainer = pl.Trainer(
logger=logger,
callbacks=[early_stopping_callback],
max_epochs=N_EPOCHS,
checkpoint_callback=checkpoint_callback,
gpus=1,
progress_bar_refresh_rate=30
)
# checkpoint_callback=checkpoint_callback,
Run Code Online (Sandbox Code Playgroud)
一旦我运行这个,我就会得到这个错误:
~/.local/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py in _configure_checkpoint_callbacks(self, checkpoint_callback)
75 if isinstance(checkpoint_callback, Callback):
76 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
---> 77 raise MisconfigurationException(error_msg)
78 if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
79 raise MisconfigurationException(
MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the `callbacks` argument in the Trainer constructor instead.
Run Code Online (Sandbox Code Playgroud)
我该如何解决这个问题?
checkpoint_callback
您可以在文档页面中查找参数的描述pl.Trainer
:
\n\n\n
checkpoint_callback
(bool) \xe2\x80\x93 如果True
,则启用检查点。如果回调中ModelCheckpoint
没有用户定义,它将配置默认回调。ModelCheckpoint
您不应该将您的自定义传递ModelCheckpoint
给此参数。我相信您要做的就是在列表EarlyStopping
中传递和:ModelCheckpoint
callbacks
early_stopping_callback = EarlyStopping(monitor=\'val_loss\', patience=2)\n\ncheckpoint_callback = ModelCheckpoint(\n dirpath="checkpoints",\n filename="best-checkpoint",\n save_top_k=1,\n verbose=True,\n monitor="val_loss",\n mode="min")\n\ntrainer = pl.Trainer(\n logger=logger,\n callbacks=[checkpoint_callback, early_stopping_callback],\n max_epochs=N_EPOCHS,\n gpus=1,\n progress_bar_refresh_rate=30)\n
Run Code Online (Sandbox Code Playgroud)\n
归档时间: |
|
查看次数: |
5339 次 |
最近记录: |