在PyTorch中实现"无限循环"数据集和数据加载器

Cov*_*ovi 7 pytorch

我想实现一个无限循环的Dataset和DataLoader.这是我试过的:

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
#         return 1<<30 # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"
Run Code Online (Sandbox Code Playgroud)

如您所见,这里的主要挑战是__len()__方法.如果我在那里放一个足够大的数字,比如1 << 30,那么在火车循环的第一次迭代中,内存使用的症状就是JUMP TO 10 + GB.过了一会儿,工人们被杀,大概是因为OOM.

如果我在那里放一个小数字,如1或BATCH_SIZE,则会定期复制列车循环中的采样"数据".这不是我想要的,因为我希望在每次迭代时生成和训练新数据.

我猜测过多的内存使用的罪魁祸首是堆栈中的某个地方,一堆东西被缓存.随便看看Python方面的事情,我无法确定在哪里.

有人可以建议什么是我想要实现的最佳方式?(使用DataLoader的并行加载,同时保证每个加载的批次都是全新的.)

Jat*_*aki 1

DataLoader对数据集进行采样而不进行替换。为此,它会生成0 到 之间的索引的随机排列len(dataset)。我猜这种排列会耗尽你的大部分记忆。我认为 PyTorch API 不支持无限集合,但您可以尝试分叉代码DataLoader并自己完成。您可以使用该batch_sampler参数,并传入一个基于RandomSampler. 这将允许您保留 的并行加载部分DataLoader

话虽这么说,基于 的迭代协议__len__并不__getitem__适合无限集合。您可能最好重新实现Dataset.__len__只 return 1Dataset.__getitem__无论索引如何,始终返回新样本,然后从该数据集中替换n采样时间。从技术上讲,它会询问第 0 个样本的时间,但由于您重写以返回不同的样本,因此这将有效地完成您正在寻找的操作。n__getitem__