FashionMNIST 数据集有 10 个不同的输出类别。如何获取仅包含特定类的数据集的子集?就我而言,我只想要运动鞋、套头衫、凉鞋和衬衫类别的图像(它们的类别分别为 7、2、5 和 6)。
\n这就是我加载数据集的方式。
\ntrain_dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())
我\xe2\x80\x99ve遵循的方法如下。\n逐一迭代数据集,然后将返回的元组中的第一个元素(即类)与我所需的类进行比较。我\xe2\x80\x99m 卡在这里。如果返回的值为 true,我如何将此观察结果追加/添加到空数据集中?
\nsneaker = 0\npullover = 0\nsandal = 0\nshirt = 0\nfor i in range(60000):\n if train_dataset_full[i][1] == 7:\n sneaker += 1\n elif train_dataset_full[i][1] == 2:\n pullover += 1\n elif train_dataset_full[i][1] == 5:\n sandal += 1\n elif train_dataset_full[i][1] == 6:\n shirt += 1\nRun Code Online (Sandbox Code Playgroud)\n现在,代替sneaker += 1、pullover += 1、sandal += 1我shirt += 1想做类似的事情empty_dataset.append(train_dataset_full[i])或类似的事情。
如果上述方法不正确,请建议其他方法。
\nNur*_*nab 10
终于找到了答案。
dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())
# Selecting classes 7, 2, 5 and 6
idx = (dataset_full.targets==7) | (dataset_full.targets==2) | (dataset_full.targets==5) | (dataset_full.targets==6)
dataset_full.targets = dataset_full.targets[idx]
dataset_full.data = dataset_full.data[idx]
Run Code Online (Sandbox Code Playgroud)