获取pytorch数据集的子集

Mir*_*ber 7 python machine-learning neural-network torch pytorch

我有一个网络,我想在一些数据集上训练(例如,说CIFAR10).我可以通过创建数据加载器对象

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
Run Code Online (Sandbox Code Playgroud)

我的问题如下:假设我想进行几次不同的训练迭代.假设我首先想要在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络,依此类推.为此,我需要能够访问这些图像.不幸的是,它似乎trainset不允许这种访问.也就是说,尝试做trainset[:1000]或更多一般trainset[mask]会抛出错误.

我可以做

trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
Run Code Online (Sandbox Code Playgroud)

然后

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
Run Code Online (Sandbox Code Playgroud)

但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改,trainset.train_data所以我需要重新定义trainset).有没有办法避免它?

理想情况下,我希望有一些"等同"的东西

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                              shuffle=True, num_workers=2)
Run Code Online (Sandbox Code Playgroud)

jay*_*elm 55

torch.utils.data.Subset更容易,支持shuffle,并且不需要编写自己的采样器:

import torchvision
import torch

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=None)

evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)

trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                            shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                            shuffle=True, num_workers=2)
Run Code Online (Sandbox Code Playgroud)

  • 将“evens”和“odds”转换为列表是没有必要的——至少在 torch 1.5.0 中,“Subset”接受生成器:“ts1 = Subset(trainset, range(0, len(trainset), 2))” (5认同)
  • @user650654 有点偏离主题,但“range”不是生成器。 (5认同)

Man*_*nas 13

您可以为数据集加载器定义自定义采样器,避免重新创建数据集(只需为每个不同的采样创建一个新的加载器).

class YourSampler(Sampler):
    def __init__(self, mask):
        self.mask = mask

    def __iter__(self):
        return (self.indices[i] for i in torch.nonzero(self.mask))

    def __len__(self):
        return len(self.mask)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

sampler1 = YourSampler(your_mask)
sampler2 = YourSampler(your_other_mask)
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          sampler = sampler1, shuffle=False, num_workers=2)
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          sampler = sampler2, shuffle=False, num_workers=2)
Run Code Online (Sandbox Code Playgroud)

PS:你可以在这里找到更多信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

  • 谢谢!一个小小的评论:显然采样器与shuffle不兼容,所以为了达到相同的结果,可以做:torch.utils.data.DataLoader(trainset,batch_size = 4,sampler = SubsetRandomSampler(np.where(mask)[0 ]),shuffle = False,num_workers = 2) (3认同)