sec*_*ret 4 python deep-learning pytorch
我训练了一个Efficentnet-b6的模型(架构如下):
https://github.com/lukemelas/EfficientNet-PyTorch
现在,我尝试加载我用它训练的模型:
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)
Run Code Online (Sandbox Code Playgroud)
但后来我收到以下错误:
_IncompatibleKeys
Run Code Online (Sandbox Code Playgroud)
missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]
Run Code Online (Sandbox Code Playgroud)
请让我知道我该如何解决这个问题,我错过了什么?谢谢你!
Was*_*mad 10
如果您比较missing_keys和unexpected_keys,您可能会意识到发生了什么。
missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]
Run Code Online (Sandbox Code Playgroud)
如您所见,模型权重以module.前缀保存。这是因为您已经使用 训练了模型DataParallel。
现在,要在不使用 的情况下加载模型权重DataParallel,您可以执行以下操作。
# original saved file with DataParallel
checkpoint = torch.load(path, map_location=torch.device('cpu'))
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
name = key.replace("module.", "") # remove `module.`
new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict, strict=False)
Run Code Online (Sandbox Code Playgroud)
或者,如果您使用 包装模型DataParallel,则不需要上述方法。
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model = torch.nn.DataParallel(model)
model.load_state_dict(checkpoint, strict=False)
Run Code Online (Sandbox Code Playgroud)
尽管不鼓励第二种方法(因为在许多情况下您可能不需要DataParallel)。
| 归档时间: |
|
| 查看次数: |
4125 次 |
| 最近记录: |