如何使用 Tensorflow 对象检测 API 继续训练对象检测模型?

Sim*_*que 6 machine-learning tensorflow google-cloud-ml object-detection-api

我正在使用Tensorflow 对象检测 API来训练使用迁移学习的对象检测模型。具体来说,我正在使用模型 zoo 中的 ssd_mobilenet_v1_fpn_coco,并使用提供示例管道,当然用实际链接替换了占位符,我的训练和评估 tfrecords 和标签。

我能够使用上述管道成功地在我的约 5000 张图像(和相应的边界框)上训练模型(如果相关的话,我主要在 TPU 上使用 Google 的 ML 引擎)。

现在,我准备了大约 2000 张额外的图像,并希望继续使用这些新图像训练我的模型,而无需从头开始(训练初始模型花费了大约 6 小时的 TPU 时间)。我怎样才能做到这一点?

net*_*sam 5

你有两个选择,在这两个你需要改变input_pathtrain_input_reader新的数据集:

  1. 在训练配置中指定要微调的检查点时,请指定训练模型的检查点
train_config{
    fine_tune_checkpoint: <path_to_your_checkpoint>
    fine_tune_checkpoint_type: "detection"
    load_all_detection_checkpoint_vars: true
}
Run Code Online (Sandbox Code Playgroud)
  1. 只需继续使用train_input_readermodel_dir之前型号相同的配置(除了)。这样,API 将创建一个图形并检查检查点是否已存在于model_dir图形中并适合该图形。如果是这样 - 它会恢复它并继续训练它。

编辑:fine_tune_checkpoint_type 之前被错误地设置为 true,而在一般情况下它应该是“检测”或“分类”,在这种特定情况下应该是“检测”。感谢 Krish 的关注。


Vla*_*-HC 1

我还没有在新数据集上重新训练对象检测模型,但看起来增加train_config.num_steps配置文件中的训练步骤数以及在 tfrecord 文件中添加图像应该足够了。