如何将生成器转换为 Pytorch 数据加载器?

Ryl*_*fer 7 pytorch pytorch-lightning

我有一个可以创建合成数据的生成器。如何将其转换为 PyTorch 数据加载器?

Iva*_*van 9

您可以用以下内容包装您的生成器data.IterableDataset

class IterDataset(data.IterableDataset):
    def __init__(self, generator):
        self.generator = generator

    def __iter__(self):
        return self.generator()
Run Code Online (Sandbox Code Playgroud)

当然,您可以用data.DataLoader.

这是一个展示其用途的最小示例:

>>> gen = lambda: [(yield x) for x in range(10)]

>>> dataset = IterDataset(gen)
>>> for i in data.DataLoader(dataset, batch_size=2):
...    print(i)
tensor([0, 1])
tensor([2, 3])
tensor([4, 5])
tensor([6, 7])
tensor([8, 9])
Run Code Online (Sandbox Code Playgroud)