如何从 PyTorch 的 FashionMNIST 数据集中获取特定类?

Nur*_*nab 7 python pytorch

FashionMNIST 数据集有 10 个不同的输出类别。如何获取仅包含特定类的数据集的子集?就我而言,我只想要运动鞋、套头衫、凉鞋和衬衫类别的图像(它们的类别分别为 7、2、5 和 6)。

\n

这就是我加载数据集的方式。

\n

train_dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())

\n

我\xe2\x80\x99ve遵循的方法如下。\n逐一迭代数据集,然后将返回的元组中的第一个元素(即类)与我所需的类进行比较。我\xe2\x80\x99m 卡在这里。如果返回的值为 true,我如何将此观察结果追加/添加到空数据集中?

\n
sneaker = 0\npullover = 0\nsandal = 0\nshirt = 0\nfor i in range(60000):\n    if train_dataset_full[i][1] == 7:\n        sneaker += 1\n    elif train_dataset_full[i][1] == 2:\n        pullover += 1\n    elif train_dataset_full[i][1] == 5:\n        sandal += 1\n    elif train_dataset_full[i][1] == 6:\n        shirt += 1\n
Run Code Online (Sandbox Code Playgroud)\n

现在,代替sneaker += 1pullover += 1sandal += 1shirt += 1想做类似的事情empty_dataset.append(train_dataset_full[i])或类似的事情。

\n

如果上述方法不正确,请建议其他方法。

\n

Nur*_*nab 10

终于找到了答案。

dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())
# Selecting classes 7, 2, 5 and 6
idx = (dataset_full.targets==7) | (dataset_full.targets==2) | (dataset_full.targets==5) | (dataset_full.targets==6)
dataset_full.targets = dataset_full.targets[idx]
dataset_full.data = dataset_full.data[idx]
Run Code Online (Sandbox Code Playgroud)