How to select specific labels in the PyTorch MNIST dataset

Aym*_*ass 7 python pytorch

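I am trying to create a DataLoader that uses only specific digits from the PyTorch MNIST dataset.

I have already tried writing my own sampler, but it does not work, and I am not sure I am using the mask correctly.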

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask):
        self.mask = mask

    def __iter__(self):
        return (self.indices[i] for i in torch.nonzero(self.mask))

    def __len__(self):
        return len(self.mask)

mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform=transform)
mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler=sampler, shuffle=False, num_workers=2)


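So far I have run into many different kinds of errors. For this implementation, it is "StopIteration". I feel like this is something simple or silly, but I cannot find an easy way to do it. Thanks for your help!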

Gep*_*o97 10

The simplest option I can think of is to reduce the dataset in place:

indices = dataset.targets == 5 # if you want to keep images with the label 5
dataset.data, dataset.targets = dataset.data[indices], dataset.targets[indices]

  • `indices = (dataset.targets == 5) | (dataset.targets == 6) | (dataset.targets == 7)` should do it (6 upvotes)
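
To make the masking step concrete, here is a minimal sketch that runs on small stand-in tensors instead of the real dataset. The values in `targets` and `data` are made up for illustration, and `torch.isin` (available in recent PyTorch releases) is an alternative I am assuming is acceptable here in place of chaining `|`; it is not part of the answer above.

```python
import torch

# Stand-in tensors shaped like MNIST's `data` and `targets` attributes
# (hypothetical values, so the sketch runs without downloading anything).
targets = torch.tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7])
data = torch.arange(len(targets)).view(-1, 1, 1).expand(-1, 28, 28).clone()

# Single label: boolean mask that is True where the label equals 5.
mask_single = targets == 5

# Several labels: OR the comparisons together...
mask_or = (targets == 5) | (targets == 6) | (targets == 7)
# ...or use torch.isin for an arbitrary list of labels.
keep_labels = torch.tensor([5, 6, 7])
mask_isin = torch.isin(targets, keep_labels)

# Indexing with the mask keeps only the matching rows, in original order.
filtered_data, filtered_targets = data[mask_isin], targets[mask_isin]
print(filtered_targets.tolist())
```

On a real `datasets.MNIST` object the same masks would be built from `dataset.targets` and assigned back to `dataset.data` and `dataset.targets`, as in the answer. Note that this overwrites the dataset object, so the discarded samples are gone unless the dataset is reloaded.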

Aym*_*ass 1

Thanks for your help. After a while I found a solution (though it is probably not the best one):

import torch
from torchvision import datasets

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask, data_source):
        self.mask = mask
        self.data_source = data_source

    def __iter__(self):
        # Yield the integer positions where the mask is nonzero.
        return iter([i.item() for i in torch.nonzero(self.mask)])

    def __len__(self):
        # Number of samples actually yielded, not the size of the full dataset.
        return int(self.mask.sum())

# dataroot, transform, batch_size and workers are assumed to be defined elsewhere.
mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform=transform)
mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask, mnist)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, sampler=sampler, shuffle=False, num_workers=workers)
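
As a possible alternative to a hand-written sampler, the built-in `torch.utils.data.Subset` can wrap a dataset together with a list of indices. The sketch below is not the method from the answer above; it uses a hypothetical `TensorDataset` with made-up labels in place of MNIST so it runs without a download.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for MNIST: 12 blank "images" with integer labels.
images = torch.zeros(12, 1, 28, 28)
labels = torch.tensor([5, 0, 4, 5, 9, 2, 5, 3, 1, 4, 5, 5])
dataset = TensorDataset(images, labels)

# Positions of the samples whose label is 5.
indices = torch.nonzero(labels == 5).flatten().tolist()

# Subset exposes only those positions, so len() and iteration agree.
subset = Subset(dataset, indices)
loader = DataLoader(subset, batch_size=4, shuffle=False)

# Collect the labels of each batch to check what the loader yields.
batch_labels = [y.tolist() for _, y in loader]
print(len(subset), batch_labels)
```

With the real dataset, the indices could presumably be computed from `mnist.targets == 5` in the same way, which avoids looping over every sample just to read its label. Because the subset has the right length, `shuffle=True` also works without a custom sampler.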