为什么在测试时必须使用DataParallel?

lik*_*yoo 3 deep-learning pytorch

在 GPU 上训练,num_gpus 设置为 1?

device_ids = list(range(num_gpus))
model = NestedUNet(opt.num_channel, 2).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
Run Code Online (Sandbox Code Playgroud)

在CPU上测试?

model = NestedUNet_Purn2(opt.num_channel, 2).to(dev)
device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
model_old = torch.load(path, map_location=dev)
pretrained_dict = model_old.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
Run Code Online (Sandbox Code Playgroud)

这将得到正确的结果,但是当我删除时呢?

device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
Run Code Online (Sandbox Code Playgroud)

结果是错误的。

Mic*_*ngo 10

nn.DataParallel包装模型,其中实际模型被分配给module属性。这也意味着状态字典中的键有一个module.前缀。

让我们看一个只有一个卷积的非常简化的版本,看看有什么不同:

class NestedUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

model = NestedUNet()

model.state_dict().keys() # => odict_keys(['conv1.weight', 'conv1.bias'])

# Wrap the model in DataParallel
model_dp = nn.DataParallel(model, device_ids=range(num_gpus))

model_dp.state_dict().keys() # => odict_keys(['module.conv1.weight', 'module.conv1.bias'])
Run Code Online (Sandbox Code Playgroud)

您保存的状态字典nn.DataParallel与常规模型的状态不一致。您正在将当前状态 dict 与加载状态 dict 合并,这意味着加载状态被忽略,因为模型没有任何属于键的属性,而是剩下随机初始化的模型。

为了避免犯这个错误,你不应该合并 state dicts,而是直接将它应用到模型中,在这种情况下,如果键不匹配就会出错。

RuntimeError: Error(s) in loading state_dict for NestedUNet:
        Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
        Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias".
Run Code Online (Sandbox Code Playgroud)

要使您保存的状态字典兼容,您可以去掉module.前缀:

RuntimeError: Error(s) in loading state_dict for NestedUNet:
        Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
        Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias".
Run Code Online (Sandbox Code Playgroud)

您还可以通过nn.DataParallel在保存状态之前解开模型来避免此问题,即 save model.module.state_dict()。因此,nn.DataParallel如果您想使用多个 GPU,您始终可以先加载模型及其状态,然后再决定将其放入。


Was*_*mad 5

DataParallel您使用并保存了模型来训练模型。因此,模型权重是用module.前缀存储的。现在,当您加载 without 时DataParallel,您基本上不会加载任何模型权重(模型具有随机权重)。因此,模型的预测是错误的。

我举个例子。

model = nn.Linear(2, 4)
model = torch.nn.DataParallel(model, device_ids=device_ids)
model.state_dict().keys() # => odict_keys(['module.weight', 'module.bias'])
Run Code Online (Sandbox Code Playgroud)

另一方面,

another_model = nn.Linear(2, 4)
another_model.state_dict().keys() # => odict_keys(['weight', 'bias'])
Run Code Online (Sandbox Code Playgroud)

看看按键上的区别OrderedDict

因此,在您的代码中,以下三行代码有效,但没有加载模型权重。

pretrained_dict = model_old.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
Run Code Online (Sandbox Code Playgroud)

在这里,model_dict有没有前缀的键module.,但当pretrained_dict您不使用 时有DataParalle。因此,pretrained_dict当不使用 DataParallel 时,本质上是空的。


解决方案:如果您想避免使用DataParallel,或者您可以加载权重文件,请创建一个不带模块前缀的新 OrderedDict,然后加载回来。

类似下面的内容将适用于您的情况,而无需使用DataParallel.

# original saved file with DataParallel
model_old = torch.load(path, map_location=dev)

# create new OrderedDict that does not contain `module.`
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in model_old.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v

# load params
model.load_state_dict(new_state_dict)
Run Code Online (Sandbox Code Playgroud)