Har*_*nen 4 python machine-learning image-processing deep-learning pytorch
我从一个教程中构建了一个简单的网络,但出现此错误:
RuntimeError:类型为torch.cuda.FloatTensor的预期对象,但为参数#4'mat1'找到类型为torch.FloatTensor的对象
有什么帮助吗?谢谢!
import torch
import torchvision
device = torch.device("cuda:0")
root = '.data/'
dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.out = torch.nn.Linear(28*28, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.out(x)
return x
net = Net()
net.to(device)
for i, (inputs, labels) in enumerate(dataloader):
inputs.to(device)
out = net(inputs)
Run Code Online (Sandbox Code Playgroud)
TL; DR
这是修复
inputs = inputs.to(device)
Run Code Online (Sandbox Code Playgroud)
为什么?!和:
之间有细微的差别torch.nn.Module.to(),torch.Tensor.to()虽然Module.to()是就地运算符,但Tensor.to()不是。因此
net.to(device)
Run Code Online (Sandbox Code Playgroud)
更改net自身并将其移至device。另一方面
inputs.to(device)
Run Code Online (Sandbox Code Playgroud)
不改变inputs,而是返回一个副本的inputs上驻留device。要使用该“在设备上”副本,您需要将其分配给一个变量,因此
inputs = inputs.to(device)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1962 次 |
| 最近记录: |