为什么我的简单pytorch网络不能在GPU设备上工作?

Har*_*nen 4 python machine-learning image-processing deep-learning pytorch

我从一个教程中构建了一个简单的网络,但出现此错误:

RuntimeError:类型为torch.cuda.FloatTensor的预期对象,但为参数#4'mat1'找到类型为torch.FloatTensor的对象

有什么帮助吗?谢谢!

import torch
import torchvision

device = torch.device("cuda:0")
root = '.data/'

dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.out = torch.nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

net = Net()
net.to(device)

for i, (inputs, labels) in enumerate(dataloader):
    inputs.to(device)
    out = net(inputs)
Run Code Online (Sandbox Code Playgroud)

Sha*_*hai 8

TL; DR
这是修复

inputs = inputs.to(device)  
Run Code Online (Sandbox Code Playgroud)

为什么?!和:
之间有细微的差别torch.nn.Module.to()torch.Tensor.to()虽然Module.to()就地运算符,但Tensor.to()不是。因此

net.to(device)
Run Code Online (Sandbox Code Playgroud)

更改net自身并将其移至device。另一方面

inputs.to(device)
Run Code Online (Sandbox Code Playgroud)

不改变inputs,而是返回一个副本inputs上驻留device。要使用该“在设备上”副本,您需要将其分配给一个变量,因此

inputs = inputs.to(device)
Run Code Online (Sandbox Code Playgroud)