有没有办法将 torch.nn.DataParallel 与 CPU 一起使用?

mos*_*rze 5 python parallel-processing pytorch

我正在尝试更改一些 PyTorch 代码,以便它可以在 CPU 上运行。

该模型经过训练,torch.nn.DataParallel()因此当我加载预训练模型并尝试使用它时,我必须使用nn.DataParallel()我目前正在做的事情,如下所示:

device = torch.device("cuda:0")
net = nn.DataParallel(net, device_ids=[0])
net.load_state_dict(torch.load(PATH))
net.to(device)
Run Code Online (Sandbox Code Playgroud)

然而,当我将我的手电筒设备切换到 CPU 后,如下所示:

device = torch.device('cpu')
net = nn.DataParallel(net, device_ids=[0])
net.load_state_dict(torch.load(PATH))
net.to(device)
Run Code Online (Sandbox Code Playgroud)

我收到这个错误:

File "C:\My\Program\win-py362-venv\lib\site-packages\torch\nn\parallel\data_parallel.py", line 156, in forward
    "them on device: {}".format(self.src_device_obj, t.device))
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu
Run Code Online (Sandbox Code Playgroud)

我假设它仍在寻找 CUDA,因为这就是device_ids设置的内容,但有没有办法让它使用 CPU?PyTorch 存储库中的这篇文章让我认为我可以,但它没有解释如何做到。

如果没有,是否有其他方法可以在您的 CPU 上使用通过 DataParallel 训练的模型?

小智 10

当您使用torch.nn.DataParallel()它时,它会在模块级别实现数据并行性。

根据文档:

在运行此 DataParallel 模块之前,并行化模块必须在 device_ids[0] 上拥有其参数和缓冲区。

因此,即使您正在.to(torch.device('cpu'))这样做,仍然希望将数据传递到 GPU。

但是,由于DataParallel是一个容器,因此您可以绕过它并通过执行以下操作仅获取原始模块:

net = net.module.to(device)
Run Code Online (Sandbox Code Playgroud)

现在它将访问您在应用DataParallel容器之前定义的原始模块。