Pau*_*aul 2 python-3.x pytorch
我正在尝试在 Pytorch 中加载 MNIST 数据集,并使用内置数据加载器来迭代训练示例。但是,在迭代器上调用 next() 时出现错误。我用 CIFAR10 没有这个问题。
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
batch_size = 128
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
dataiter = iter(dataloader)
dataiter.next() # ERROR
# RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
Run Code Online (Sandbox Code Playgroud)
我正在使用 Python 3.7.3 和 PyTorch 1.1.0
MNIST
数据集由灰度图像组成,即每个图像只有1
通道,而 CIFAR10
数据集由彩色图像组成,即每个图像都有3
通道。
因此,如果是MNIST
数据集,请将transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
to替换为transforms.Normalize([0.5], [0.5])
。
归档时间: |
|
查看次数: |
1162 次 |
最近记录: |