pytorch加载_IncompleteKeys

sec*_*ret 4 python deep-learning pytorch

我训练了一个Efficentnet-b6的模型(架构如下):

https://github.com/lukemelas/EfficientNet-PyTorch

现在,我尝试加载我用它训练的模型:

checkpoint  = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)
Run Code Online (Sandbox Code Playgroud)

但后来我收到以下错误:

_IncompatibleKeys
Run Code Online (Sandbox Code Playgroud)
missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]
Run Code Online (Sandbox Code Playgroud)

请让我知道我该如何解决这个问题,我错过了什么?谢谢你!

Was*_*mad 10

如果您比较missing_keysunexpected_keys,您可能会意识到发生了什么。

missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]
Run Code Online (Sandbox Code Playgroud)

如您所见,模型权重以module.前缀保存。这是因为您已经使用 训练了模型DataParallel

现在,要在不使用 的情况下加载模型权重DataParallel,您可以执行以下操作。

# original saved file with DataParallel
checkpoint = torch.load(path, map_location=torch.device('cpu'))

# create new OrderedDict that does not contain `module.`
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = key.replace("module.", "") # remove `module.`
    new_state_dict[name] = v

# load params
model.load_state_dict(new_state_dict, strict=False)
Run Code Online (Sandbox Code Playgroud)

或者,如果您使用 包装模型DataParallel,则不需要上述方法。

checkpoint  = torch.load('model.pth', map_location=torch.device('cpu'))
model = torch.nn.DataParallel(model)
model.load_state_dict(checkpoint, strict=False)
Run Code Online (Sandbox Code Playgroud)

尽管不鼓励第二种方法(因为在许多情况下您可能不需要DataParallel)。