Pytorch - 无法切片 torchvision MNIST 数据集

u2g*_*les 8 python dataset slice pytorch

在Pytorch中,当使用torchvision的MNIST数据集时,我们可以得到一个数字,如下所示:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

tsfm = transforms.Compose([transforms.Resize((16, 16)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])

mnist_ds = datasets.MNIST(root='../../../_data/mnist',train=True,download=True,
                          transform=tsfm)

digit_12 = mnist_ds[12]
Run Code Online (Sandbox Code Playgroud)

虽然可以对许多数据集进行切片,但我们不能对这个数据集进行切片:

>>> digit_12_to_14 = mnist_ds[12:15]
ValueError: Too many dimensions: 3 > 2.
Run Code Online (Sandbox Code Playgroud)

这是Image.fromarray()由于getItem().

是否可以在不使用 Dataloader 的情况下使用 MNIST 数据集?


PS:我想避免使用 Dataloader 的原因是一次向 GPU 发送一批数据会减慢训练速度。我更喜欢一次性将整个数据集发送到 GPU。为此,我需要访问整个转换后的数据集。

Jat*_*aki 2

接口Dataset只需要

所有子类都应该重写__len__,它提供数据集的大小,并且,支持从到排除__getitem__范围内的整数索引。0len(self)

这显然没有提到切片 - 其他数据集的切片行为是一个额外的功能。如果您想立即获取整个数据,您可以查找实现仅使用mnist.datamnist.targets末尾定义的张量__init__

如果你想转换数据,你可以使用

data = [mnist_ds[i] for i in range(len(mnist_ds))]
xs = torch.stack([d[0] for d in data], dim=0)
ys = torch.stack([d[1] for d in data], dim=0)
Run Code Online (Sandbox Code Playgroud)

或一次全部变换mnist.data张量(尽管这不适用于torchvision.transform变换)。