TF Slim:在自定义数据集上微调mobilenet v2

Rav*_*avi 2 python tensorflow tensorflow-slim

我正在尝试在用于图像分类任务的自定义数据集上微调Mobilenet_v2_1.4_224模型。我正在关注本教程TensorFlow-Slim图像分类库。我已经创建了.tfrecord训练和验证文件。当我尝试从现有检查点进行微调时,出现以下错误:

InvalidArgumentError(请参阅上面的回溯):Assign需要两个张量的形状匹配。lhs shape = [1,1,24,144] rhs shape = [1,1,32,192] [[节点:save / Assign_149 =分配[T = DT_FLOAT,_class = [“ loc:@ MobilenetV2 / expanded_conv_2 / expand / weights”] ,use_locking = true,validate_shape = true,_device =“ / job:localhost / replica:0 / task:0 / device:CPU:0”](MobilenetV2 / expanded_conv_2 / expand / weights,保存/恢复V2:149)]]

我使用的微调脚本是:

DATASET_DIR = G:\数据集

TRAIN_DIR = G:\ Dataset \ emotion-models \ mobilenet_v2

CHECKPOINT_PATH = C:\ Users \ lenovo \ Desktop \ mobilenet_v2 \ mobilenet_v2_1.4_224.ckpt

python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=emotion \
--dataset_split_name=train \
--model_name=mobilenet_v2 \
--train_image_size=224 \
--clone_on_cpu=True \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=MobilenetV2/Logits \
--trainable_scopes=MobilenetV2/Logits
Run Code Online (Sandbox Code Playgroud)

我怀疑该错误是由于最后两个参数“ checkpoint_exclude_scopes”或“ trainable_scopes”引起的。

我知道通过删除最后2层并为自定义数据集分类创建我们自己的softmax层,这2个参数已用于传递学习。但是我不确定是否要为它们传递正确的值。

sid*_*rth 5

要重新训练模型,您必须微调您的自定义班级数量

MobilenetV2 / Predictions和MobilenetV2 / predics

--checkpoint_exclude_scopes=MobilenetV2/Logits,MobilenetV2/Predictions,MobilenetV2/predics \
--trainable_scopes=MobilenetV2/Logits,MobilenetV2/Predictions,MobilenetV2/predics \
Run Code Online (Sandbox Code Playgroud)

在mobilenet和mobilenet_base的 mobilenet_v2.py中,depth_multiplier = 1时,应将其更改为1.4

@slim.add_arg_scope 
def mobilenet_base(input_tensor, depth_multiplier=1.4, **kwargs): 
"""Creates base of the mobilenet (no pooling and no logits) .""" 
return mobilenet(input_tensor,
                 depth_multiplier=depth_multiplier,
                 base_only=True, **kwargs)

@slim.add_arg_scope 
def mobilenet(input_tensor,
                  num_classes=1001,
                  depth_multiplier=1.4,
                  scope='MobilenetV2',
                  conv_defs=None,
                  finegrain_classification_mode=False,
                  min_depth=None,
                  divisible_by=None,
                  **kwargs):
Run Code Online (Sandbox Code Playgroud)