Tensorflow 对象检测 API - 模型的微调是如何工作的?

dav*_*vid 7 machine-learning object-detection tensorflow

这是一个关于 Tensorflow Object-Detection API 的更普遍的问题。

我正在使用这个 API,更具体地说,我将模型微调到我的数据集。根据 API 的描述,我使用该model_main.py函数从给定的检查点/冻结图重新训练模型。

但是,我不清楚微调在 API 中是如何工作的。最后一层的重新初始化是自动发生的还是我必须实现类似的东西?在README文件中,我没有找到有关此主题的任何提示。也许有人可以帮助我。

dan*_*ang 21

从stratch训练或从检查点训练model_main.py是主程序,除此程序外,您只需要一个正确的管道配置文件。

所以对于微调,可以分为两个步骤,恢复权重和更新权重。这两个步骤都可以根据train proto文件进行自定义配置,这个proto对应train_config的是pipeline中的配置文件。

train_config: {
   batch_size: 24
   optimizer { }
   fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
   fine_tune_checkpoint_type:  "detection"
   # Note: The below line limits the training process to 200K steps, which we
   # empirically found to be sufficient enough to train the pets dataset. This
   # effectively bypasses the learning rate schedule (the learning rate will
   # never decay). Remove the below line to train indefinitely.
   num_steps: 200000
   data_augmentation_options {}
 }
Run Code Online (Sandbox Code Playgroud)

第一步,恢复权重。

在这一步中,您可以通过设置来配置要恢复的变量fine_tune_checkpoint_type,选项为detectionclassification。通过将其设置为detection本质上,您可以从检查点恢复几乎所有变量,通过将其设置为classification,仅feature_extractor恢复范围内的变量,(骨干网络中的所有层,如 VGG、Resnet、MobileNet,它们被称为特征提取器) .

以前这是由from_detection_checkpoint和 控制的load_all_detection_checkpoint_vars,但这两个字段已被弃用。

还要注意,配置好之后fine_tune_checkpoint_type,实际的恢复操作会检查图中的变量是否存在于checkpoint中,如果不存在,则使用例行初始化操作来初始化该变量。

举个例子,假设你想微调一个ssd_mobilenet_v1_custom_data模型并且你下载了 checkpoint ssd_mobilenet_v1_coco,当你设置fine_tune_checkpoint_type: detection,那么图中所有在 checkpoint 文件中也可用的变量将被恢复,并且框预测器(最后一层)权重也将恢复。但是,如果您设置了fine_tune_checkpoint_type: classification,则只会mobilenet恢复图层的权重。但是,如果您使用不同的模型检查点,例如faster_rcnn_resnet_xxx,那么由于图表中的变量在检查点中不可用,您将看到输出日志说Variable XXX is not available in checkpoint警告,并且它们不会被恢复。

第二步,更新权重

现在你已经恢复了所有的权重,并且你想在你自己的数据集上继续训练(微调),通常这应该就足够了。

但是如果你想尝试一些东西并且你想在训练过程中冻结一些层,那么你可以通过设置freeze_variables. 假设您想冻结 mobilenet 的所有权重并仅更新框预测器的权重,您可以设置为不更新名称freeze_variables: [feature_extractor]中包含的所有变量feature_extractor。有关详细信息,请参阅我写的另一个答案

因此,要在自定义数据集上微调模型,您应该准备一个自定义配置文件。您可以从示例配置文件开始,然后修改一些字段以满足您的需要。