如何从pytorch中从纪元增长到纪元的高IO数据集读取

Dav*_*rks 9 python pytorch

我使用 Tensorflow,但我正在为用户编写文档,这些文档通常会因深度学习框架而异

当处理不适合本地文件系统 (TB+) 的数据集时,我从远程数据存储中采样数据并将样本在本地写入 Tensorflow 标准tfrecords格式。

在训练的第一个 epoch 期间,我将只采样几个值,因此本地数据的epoch非常小,我对其进行训练。在epoch 2 上,我重新检查我的采样子流程(现在更多)生成了哪些数据文件,并在下一个 epoch 的本地数据文件的扩展集上进行训练。每个时期重复该过程。通过这种方式,我建立了一个本地样本缓存,并可以在我填满本地存储时驱逐旧样本。本地样本缓存大约在模型最需要方差的时间增长(朝向训练的后期)。

在 Python/Tensorflow 中,重要的是我不要在 Python 训练循环过程中反序列化数据,因为 Python GIL 无法支持数据传输速率(300-600 MB/秒,数据是原始科学不可压缩的),因此 GPU 性能当 Python GIL 无法快速为训练循环提供服务时,就会受到影响。

将样本tfrecords从子进程(python 多处理)写入文件允许 tensorflow 的本机TFRecordsDataset在 Python 之外进行反序列化,因此我们避开了 Python GIL 问题,并且我可以使 GPU 达到高 IO 数据速率。

我想知道我将如何在 Pytorch 中解决这个问题。我正在撰写有关正在使用的采样策略的文章,并希望向 Tensorflow 和 PyTorch 的用户提供具体的建议,但我对 PyTorch 预处理生态系统的了解不够深入,无法写出足够详细的内容。

旁注:支持这些数据传输率的唯一基于 Python 的解决方案可能来自 Python 3.8,它带有 System V 共享内存和多处理,但我还没有尝试过,因为对它的支持还不够(很快就会)。现有的多处理解决方案是不够的,因为它们需要在训练循环过程中进行反序列化,因此在反序列化期间以高 IO 速率锁定 GIL。

bom*_*mbs 10

实际上,您可以使用torch.utils.data.DataLoader. 通过将num_workers参数设置为 1 或更大的值,您可以使用自己的 Python 解释器和 GIL 生成子进程。

loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
    for batch_idx, data in enumerate(loader):
         # loader in the main process does not claim GIL at this point
Run Code Online (Sandbox Code Playgroud)

ADataloader需要 atorch.utils.data.Dataset从中获取数据。在您的情况下实现适当的子类可能不是一件容易的事。如果您需要Dataset为每个 epoch重新创建一个实例,您可以这样做。

for epcoh in range(epochs):
    dset = get_new_dataset()
    loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
    for batch_idx, data in enumerate(loader):
        # Do training
Run Code Online (Sandbox Code Playgroud)

甚至更好

dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)

for epcoh in range(epochs):
    last_batch_idx =  (len(dset)-1) // loader.batch_size
    for batch_idx, data in enumerate(loader):
        # Prepare next loader in advance to avoid blocking
        if batch_idx == last_batch_idx:
            dset = get_new_dataset()
            loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
        # Do training
Run Code Online (Sandbox Code Playgroud)

作为旁注,请注意,在大多数情况下,受 GIL 影响的是 CPU 绑定操作,而不是 I/O 绑定操作,即,threading对于任何纯 I/O 繁重的操作,您甚至不需要subprocess. 有关更多信息,请参阅此问题和这篇维基百科文章