如何在 Tensorflow 对象检测 API 中存储最佳模型检查点,而不仅仅是最新的 5 个?

Pio*_*ski 7 object-detection tensorflow

我正在 WIDER FACE 数据集上训练 MobileNet,但遇到了无法解决的问题。TF 对象检测 API 仅在traindir 中存储最后 5 个检查点,但我想做的是保存相对于 mAP 指标的最佳模型(或至少train在删除之前在dir 中保留更多模型)。例如,今天我在第二天晚上的训练后查看了 Tensorboard,我发现隔夜模型过度拟合,我无法恢复最佳检查点,因为它已经被删除了。

编辑:我只使用Tensorflow Object Detection API,它默认保存我指向的 train dir 中的最后 5 个检查点。我寻找一些配置参数或任何会改变这种行为的东西。

有没有人在代码/配置参数中有一些修复来设置/解决方法?似乎我遗漏了一些东西,很明显,实际上重要的是最好的模型,而不是最新的模型(可能会过拟合)。

谢谢!

Dav*_*sia 6

您可以修改(在您的 fork 中硬编码或打开拉取请求并将选项添加到 protos)传递给tf.train.Saver的参数:

https://github.com/tensorflow/models/blob/master/research/object_detection/legacy/trainer.py#L376-L377

您可能想要设置:

  • max_to_keep:要保留的最近检查点的最大数量。默认为 5。
  • keep_checkpoint_every_n_hours:保持检查点的频率。默认为 10,000 小时。