我想实现一个无限循环的Dataset和DataLoader.这是我试过的:
class Infinite(Dataset):
def __len__(self):
return HPARAMS.batch_size
# return 1<<30 # This causes huge memory usage.
def __getitem__(self, idx):
"""Randomly generates one new example."""
return sample_func_to_be_parallelized()
infinite_loader = DataLoader(
dataset=Infinite(),
batch_size=HPARAMS.batch_size,
num_workers=16,
worker_init_fn=lambda worker_id: np.random.seed(worker_id),
)
while True:
for idx, data in enumerate(infinite_loader):
# forward + backward on "data"
Run Code Online (Sandbox Code Playgroud)
如您所见,这里的主要挑战是__len()__
方法.如果我在那里放一个足够大的数字,比如1 << 30,那么在火车循环的第一次迭代中,内存使用的症状就是JUMP TO 10 + GB.过了一会儿,工人们被杀,大概是因为OOM.
如果我在那里放一个小数字,如1或BATCH_SIZE,则会定期复制列车循环中的采样"数据".这不是我想要的,因为我希望在每次迭代时生成和训练新数据.
我猜测过多的内存使用的罪魁祸首是堆栈中的某个地方,一堆东西被缓存.随便看看Python方面的事情,我无法确定在哪里.
有人可以建议什么是我想要实现的最佳方式?(使用DataLoader的并行加载,同时保证每个加载的批次都是全新的.)
DataLoader
对数据集进行采样而不进行替换。为此,它会生成0 到 之间的索引的随机排列len(dataset)
。我猜这种排列会耗尽你的大部分记忆。我认为 PyTorch API 不支持无限集合,但您可以尝试分叉代码DataLoader
并自己完成。您可以使用该batch_sampler
参数,并传入一个基于RandomSampler
. 这将允许您保留 的并行加载部分DataLoader
。
话虽这么说,基于 的迭代协议__len__
并不__getitem__
适合无限集合。您可能最好重新实现Dataset.__len__
只 return 1
,Dataset.__getitem__
无论索引如何,始终返回新样本,然后从该数据集中替换n
采样时间。从技术上讲,它会询问第 0 个样本的时间,但由于您重写以返回不同的样本,因此这将有效地完成您正在寻找的操作。n
__getitem__
归档时间: |
|
查看次数: |
1045 次 |
最近记录: |