相关疑难解决方法(0)

为什么在测试时必须使用DataParallel?

在 GPU 上训练,num_gpus 设置为 1?

device_ids = list(range(num_gpus))
model = NestedUNet(opt.num_channel, 2).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
Run Code Online (Sandbox Code Playgroud)

在CPU上测试?

model = NestedUNet_Purn2(opt.num_channel, 2).to(dev)
device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
model_old = torch.load(path, map_location=dev)
pretrained_dict = model_old.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
Run Code Online (Sandbox Code Playgroud)

这将得到正确的结果,但是当我删除时呢?

device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
Run Code Online (Sandbox Code Playgroud)

结果是错误的。

deep-learning pytorch

3
推荐指数
2
解决办法
1164
查看次数

标签 统计

deep-learning ×1

pytorch ×1