Ryl*_*fer 7 pytorch pytorch-lightning
我有一个可以创建合成数据的生成器。如何将其转换为 PyTorch 数据加载器?
您可以用以下内容包装您的生成器data.IterableDataset:
class IterDataset(data.IterableDataset):
def __init__(self, generator):
self.generator = generator
def __iter__(self):
return self.generator()
Run Code Online (Sandbox Code Playgroud)
当然,您可以用data.DataLoader.
这是一个展示其用途的最小示例:
>>> gen = lambda: [(yield x) for x in range(10)]
>>> dataset = IterDataset(gen)
>>> for i in data.DataLoader(dataset, batch_size=2):
... print(i)
tensor([0, 1])
tensor([2, 3])
tensor([4, 5])
tensor([6, 7])
tensor([8, 9])
Run Code Online (Sandbox Code Playgroud)