使用特定元素自定义批次

6 pytorch pytorch-dataloader

我是 PyTorch 的新手。奇怪的是,我找不到与此相关的任何内容,尽管它看起来很简单。

我想用特定的示例来构建我的批次,例如每批次的所有示例都具有相同的标签,或者只是用仅 2 个类的示例填充批次。

我该怎么做呢?对我来说,这似乎是数据加载器中的正确位置,而不是数据集中的正确位置?由于数据加载器负责批次而不是数据集?

有一个简单的最小示例吗?

Iva*_*van 20

太长了;

  1. 默认DataLoader仅使用采样器,而不是批量采样器。

  2. 可以定义一个采样器,再加上一个批量采样器,批量采样器会覆盖采样器。

  3. 采样器仅产生数据集元素的序列,而不是实际的批次(这由数据加载器处理,具体取决于batch_size)。


回答您最初的问题:在可迭代数据集上使用采样器似乎是不可能的 Github 问题(仍然开放)。另请阅读以下关于 的注释pytorch/dataloader.py


采样器(用于地图样式数据集):

除此之外,如果您要切换到地图样式数据集,这里有一些有关采样器和批量采样器如何工作的详细信息。您可以使用索引访问数据集的基础数据,就像使用列表一样因为torch.utils.data.DatasetImplements __getitem__)。换句话说,您的数据集元素都是dataset[i], for iin [0, len(dataset) - 1]

这是一个玩具数据集:

class DS(Dataset):
    def __getitem__(self, index):
        return index
        
    def __len__(self):
        return 10
Run Code Online (Sandbox Code Playgroud)

在一般用例中,您只需给出torch.utils.data.DataLoader参数batch_sizeshuffle。默认情况下,shuffle设置为false,这意味着它将使用torch.utils.data.SequentialSampler. 否则(如果shuffletruetorch.utils.data.RandomSampler将被使用。采样器定义数据加载器如何访问数据集(访问的顺序)。

上述数据集 ( DS) 有10 个元素。索引为0, 1, 2, 3, 4, 5, 6, 7, 8, 和9。它们映射到元素0, 10, 20, 30, 40, 50, 60, 70, 80, 和90。因此,批量大小为2

  • SequentialSampler:(DataLoader(ds, batch_size=2)隐式shuffle=False),与 相同DataLoader(ds, batch_size=2, sampler=SequentialSampler(ds))。数据加载器将提供tensor([0, 10])tensor([20, 30])tensor([40, 50])tensor([60, 70])tensor([80, 90])

  • RandomSamplerDataLoader(ds, batch_size=2, shuffle=True), 相同DataLoader(ds, batch_size=2, sampler=RandomSampler(ds))每次迭代时,数据加载器都会随机采样。例如:tensor([50, 40])tensor([90, 80])tensor([0, 60])tensor([10, 20])、 和tensor([30, 70])。但是如果您第二次迭代数据加载器,顺序将会不同!


批量采样器

提供batch_sampler将完全覆盖 batch_sizeshufflesamplerdrop_last。它旨在准确定义批处理元素及其内容。例如:

>>> DataLoader(ds, batch_sampler=[[1,2,3], [6,5,4], [7,8], [0,9]])` 
Run Code Online (Sandbox Code Playgroud)

将产生tensor([10, 20, 30])tensor([60, 50, 40])tensor([70, 80])tensor([ 0, 90])


班级批量抽样

假设我只想在批次中每个类中有 2 个元素(不同或不同),并且必须排除每个类的更多示例。因此请确保批次内不包含 3 个示例。

假设您有一个包含四个类的数据集。我将这样做。首先,跟踪每个类别的数据集索引。

class DS(Dataset):
    def __init__(self, data):
        super(DS, self).__init__()
        self.data = data

        self.indices = [[] for _ in range(4)]
        for i, x in enumerate(data):
            if x > 0 and x % 2: self.indices[0].append(i)
            if x > 0 and not x % 2: self.indices[1].append(i)
            if x < 0 and x % 2: self.indices[2].append(i)
            if x < 0 and not x % 2: self.indices[3].append(i)

    def classes(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]
Run Code Online (Sandbox Code Playgroud)

例如:

>>> ds = DS([1, 6, 7, -5, 10, -6, 8, 6, 1, -3, 9, -21, -13, 11, -2, -4, -21, 4])
Run Code Online (Sandbox Code Playgroud)

会给:

>>> ds.classes()
[[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]
Run Code Online (Sandbox Code Playgroud)

然后,对于批量采样器,最简单的方法是创建可用的类索引列表,并且具有与数据集元素一样多的类索引。

在上面定义的数据集中,我们有5 个来自 class 的项目05 个来自 class 的项目15 个来自 class 的项目2,以及3 个来自 class 的项目3。因此我们想要构建[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3]. 我们将对其进行洗牌。然后,根据此列表和数据集类内容 ( ds.classes()),我们将能够构建批次。

class Sampler():
    def __init__(self, classes):
        self.classes = classes

    def __iter__(self):
        classes = copy.deepcopy(self.classes)
        indices = flatten([[i for _ in range(len(klass))] for i, klass in enumerate(classes)])
        random.shuffle(indices)
        grouped = zip(*[iter(indices)]*2)

        res = []
        for a, b in grouped:
            res.append((classes[a].pop(), classes[b].pop()))
        return iter(res)
Run Code Online (Sandbox Code Playgroud)

注意- 需要深度复制列表,因为我们要从中弹出元素。

该采样器的可能输出是:

[(15, 14), (16, 17), (7, 12), (11, 6), (13, 10), (5, 4), (9, 8), (2, 0), (3, 1)]
Run Code Online (Sandbox Code Playgroud)

此时我们可以简单地使用torch.data.utils.DataLoader

>>> dl = DataLoader(ds, batch_sampler=sampler(ds.classes()))
Run Code Online (Sandbox Code Playgroud)

这可能会产生类似的结果:

[tensor([ 4, -4]), tensor([-21,  11]), tensor([-13,   6]), tensor([9, 1]), tensor([  8, -21]), tensor([-3, 10]), tensor([ 6, -2]), tensor([-5,  7]), tensor([-6,  1])]
Run Code Online (Sandbox Code Playgroud)

更简单的方法

这是另一种更简单的方法,它不能保证从数据集中返回所有元素,平均而言它会......

对于每个批次,首先对class_per_batch类进行采样,然后batch_size从这些选定的类中对元素进行采样(首先从该类子集中采样一个类,然后从该类的数据点中采样)。

class Sampler():
    def __init__(self, classes, class_per_batch, batch_size):
        self.classes = classes
        self.n_batches = sum([len(x) for x in classes]) // batch_size
        self.class_per_batch = class_per_batch
        self.batch_size = batch_size

    def __iter__(self):
        classes = random.sample(range(len(self.classes)), self.class_per_batch)
        
        batches = []
        for _ in range(self.n_batches):
            batch = []
            for i in range(self.batch_size):
                klass = random.choice(classes)
                batch.append(random.choice(self.classes[klass]))
            batches.append(batch)
        return iter(batches)
Run Code Online (Sandbox Code Playgroud)

你可以这样尝试:

>>> s = Sampler(ds.classes(), class_per_batch=2, batch_size=4)
>>> list(s)
[[16, 0, 0, 9], [10, 8, 11, 2], [16, 9, 16, 8], [2, 9, 2, 3]]

>>> dl = DataLoader(ds, batch_sampler=s)
>>> list(iter(dl))
[tensor([ -5,  -6, -21, -13]), tensor([ -4,  -4, -13, -13]), tensor([ -3, -21,  -2,  -5]), tensor([-3, -5, -4, -6])]
Run Code Online (Sandbox Code Playgroud)