相关疑难解决方法(0)

Pytorch - 无法切片 torchvision MNIST 数据集

在Pytorch中,当使用torchvision的MNIST数据集时,我们可以得到一个数字,如下所示:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

tsfm = transforms.Compose([transforms.Resize((16, 16)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])

mnist_ds = datasets.MNIST(root='../../../_data/mnist',train=True,download=True,
                          transform=tsfm)

digit_12 = mnist_ds[12]
Run Code Online (Sandbox Code Playgroud)

虽然可以对许多数据集进行切片,但我们不能对这个数据集进行切片:

>>> digit_12_to_14 = mnist_ds[12:15]
ValueError: Too many dimensions: 3 > 2.
Run Code Online (Sandbox Code Playgroud)

这是Image.fromarray()由于getItem().

是否可以在不使用 Dataloader 的情况下使用 MNIST 数据集?


PS:我想避免使用 Dataloader 的原因是一次向 GPU 发送一批数据会减慢训练速度。我更喜欢一次性将整个数据集发送到 GPU。为此,我需要访问整个转换后的数据集。

python dataset slice pytorch

8
推荐指数
1
解决办法
4903
查看次数

获取pytorch数据集的子集

我有一个网络,我想在一些数据集上训练(例如,说CIFAR10).我可以通过创建数据加载器对象

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
Run Code Online (Sandbox Code Playgroud)

我的问题如下:假设我想进行几次不同的训练迭代.假设我首先想要在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络,依此类推.为此,我需要能够访问这些图像.不幸的是,它似乎trainset不允许这种访问.也就是说,尝试做trainset[:1000]或更多一般trainset[mask]会抛出错误.

我可以做

trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
Run Code Online (Sandbox Code Playgroud)

然后

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
Run Code Online (Sandbox Code Playgroud)

但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改,trainset.train_data所以我需要重新定义trainset).有没有办法避免它?

理想情况下,我希望有一些"等同"的东西

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                              shuffle=True, num_workers=2)
Run Code Online (Sandbox Code Playgroud)

python machine-learning neural-network torch pytorch

7
推荐指数
2
解决办法
9886
查看次数