PyTorch DataLoader num_workers and prefetch_factor do not scale

kyc*_*c12 7 python-multiprocessing pytorch pytorch-dataloader

The serialization and deserialization that comes with Python multiprocessing seems to limit the benefit of processing data in parallel.

In the example below, I create a custom iterable that returns a numpy array. As the numpy array gets bigger, fetching the data becomes the bottleneck. That is expected. However, I would expect increasing num_workers and prefetch_factor to reduce this bottleneck by preparing batches in advance, but I do not see that behavior in the example below.

I tested two cases for what MyIterable returns:

  1. a small object: np.array((10, 150))
  2. a large object: np.array((1000, 150))

The average time per batch in the two cases:

# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253

For the small object, the time per batch drops as expected when num_workers increases. For the large object it barely changes. I attribute this to the worker processes having to serialize the np object and the main process then deserializing it; the bigger the object, the longer that takes.
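
A minimal sketch to sanity-check that hypothesis: time a pickle round trip for the two array sizes, since DataLoader worker processes hand batches back to the main process through multiprocessing queues, which pickle the payload. The exact numbers will of course vary by machine.

import pickle
import time

import numpy as np

# Round-trip pickle cost for the small and large arrays used above.
for shape in [(10, 150), (1000, 150)]:
    arr = np.random.random(shape)
    t0 = time.perf_counter()
    for _ in range(100):
        pickle.loads(pickle.dumps(arr))
    per_trip = (time.perf_counter() - t0) / 100
    print(f'pickle round trip for {shape}: {per_trip * 1e3:.3f} ms')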

However, when num_workers and prefetch_factor are large enough, shouldn't the queue inside the DataLoader always stay filled, so that data fetching is no longer the bottleneck?

Also, changing prefetch_factor does not change anything. What is prefetch_factor actually for? The documentation says that num_workers * prefetch_factor batches are prefetched across all workers, but as you can see it does nothing to reduce the bottleneck here.

For reference, I have added a more detailed step-by-step analysis to this question.

import time
import torch
import numpy as np
from time import sleep
from torch.utils.data import DataLoader, IterableDataset


def collate_fn(records):
    # some custom collation function
    return records


class MyIterable(object):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.n:
            self.i += 1
            sleep(0.003125)  # simulates data fetch time
            # return np.random.random((10, 150))  # small data item
            return np.random.random((1000, 150))  # large data item
        else:
            raise StopIteration


class MyIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return MyIterable(self.n)


def get_performance_metrics(num_workers):
    ds = MyIterableDataset(n=10000)

    if num_workers == 0:
        dl = torch.utils.data.DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
    else:
        dl = torch.utils.data.DataLoader(ds, num_workers=num_workers, prefetch_factor=4, persistent_workers=True,
                                         batch_size=128, collate_fn=collate_fn,
                                         multiprocessing_context='spawn')
    warmup = 5
    times = []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        e = time.perf_counter()
        if i >= warmup:
            times.append(e - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break
    print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')


if __name__ == '__main__':
    num_worker_options = [0, 2, 4, 6, 8]
    for n in num_worker_options:
        get_performance_metrics(n)
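
To check whether the worker queue can ever stay full for the large arrays, here is a hedged diagnostic sketch (it reuses MyIterableDataset and collate_fn from the code above; the helper name time_loader_only is just for illustration): iterate the same DataLoader without the simulated train step, so the timing reflects only how fast batches arrive from the workers. If the per-batch time stays near the ~0.38 s floor measured above, the inter-process transfer itself is the limit and a deeper prefetch queue cannot hide it.

import time
from torch.utils.data import DataLoader

def time_loader_only(num_workers, prefetch_factor=4, n_batches=20, warmup=5):
    # Same DataLoader configuration as above, minus the sleep(0.05) train step,
    # so only the producer (worker) side is measured.
    ds = MyIterableDataset(n=10000)
    dl = DataLoader(ds, num_workers=num_workers, prefetch_factor=prefetch_factor,
                    persistent_workers=True, batch_size=128,
                    collate_fn=collate_fn, multiprocessing_context='spawn')
    times = []
    t0 = time.perf_counter()
    for i, _ in enumerate(dl):
        e = time.perf_counter()
        if i >= warmup:  # skip worker start-up, mirroring the script above
            times.append(e - t0)
        t0 = time.perf_counter()
        if i >= n_batches:
            break
    print(f'loader-only avg time per batch for num workers={num_workers}: '
          f'{sum(times) / len(times):.4f}')

if __name__ == '__main__':
    for n in [2, 4, 8]:
        time_loader_only(n)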