Asked by kyc*_*c12 · Tags: python-multiprocessing, pytorch, pytorch-dataloader
Serialization and deserialization in Python's multiprocessing seem to limit the benefit of processing data in parallel.

In the example below, I create a custom iterable that returns a numpy array. As the numpy array gets larger, fetching the data becomes the bottleneck. That is expected. However, I would have expected increasing num_workers and prefetch_factor to reduce this bottleneck by preparing batches ahead of time. I do not see that behavior in the example below.
I tested two cases, with MyIterable returning arrays of shape (10, 150) and (1000, 150). The average time per batch in each case is as follows:
# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
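As a sanity check, the small-object numbers match a back-of-the-envelope estimate: with batch_size=128 and a 0.003125 s fetch per item, one batch costs 128 × 0.003125 ≈ 0.4 s to produce, plus the 0.05 s simulated train step, giving ≈ 0.45 s for num_workers=0 (0.47 observed). With N workers fetching in parallel the expected floor is roughly max(0.4 / N, 0.05), which lines up with the measured 0.21, 0.11, 0.07, and 0.05 for 2, 4, 6 and 8 workers.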
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253
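Note that even with num_workers=0, where there is no inter-process communication at all, the large case is slower (0.61 vs 0.47); part of that gap is presumably just the cost of generating and collating the larger arrays. The more telling number is the plateau around 0.38 s at 6-8 workers, far above the 0.05 s train-step floor that the small case reaches.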
For the small object, the time per batch drops as num_workers increases, as expected. For the larger object it changes very little. I attribute this to the worker processes having to serialize the np object and the main process then deserializing it: the larger the object, the more time that takes.
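To put a rough number on that serialization cost, here is a minimal standalone sketch (not part of the benchmark script below; time_pickle_roundtrip is just an illustrative name) that times a pickle round trip of one batch-sized list of arrays:

import pickle
import time
import numpy as np

def time_pickle_roundtrip(shape, repeats=50):
    # One DataLoader batch here is a list of 128 numpy arrays,
    # since collate_fn returns the records unchanged.
    batch = [np.random.random(shape) for _ in range(128)]
    t0 = time.perf_counter()
    for _ in range(repeats):
        payload = pickle.dumps(batch)  # roughly what a worker does to send a batch
        pickle.loads(payload)          # roughly what the main process does to receive it
    return (time.perf_counter() - t0) / repeats

print('small batch:', time_pickle_roundtrip((10, 150)))
print('large batch:', time_pickle_roundtrip((1000, 150)))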
However, when num_workers and prefetch_factor are large enough, shouldn't the queue inside the dataloader always be full, so that data fetching never becomes the bottleneck?

Also, changing prefetch_factor does not change anything. What is the point of prefetch_factor then? The documentation says that num_workers * prefetch_factor batches are prefetched across the workers, but as you can see it does nothing to reduce the bottleneck here.
For reference, I added a more detailed step-by-step analysis to this question.
import time
from time import sleep

import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset


def collate_fn(records):
    # some custom collation function; here it returns the records unchanged
    return records


class MyIterable(object):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.n:
            self.i += 1
            sleep(0.003125)  # simulates data fetch time
            # return np.random.random((10, 150))  # small data item
            return np.random.random((1000, 150))  # large data item
        else:
            raise StopIteration


class MyIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return MyIterable(self.n)


def get_performance_metrics(num_workers):
    ds = MyIterableDataset(n=10000)
    if num_workers == 0:
        dl = DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
    else:
        dl = DataLoader(ds, num_workers=num_workers, prefetch_factor=4, persistent_workers=True,
                        batch_size=128, collate_fn=collate_fn,
                        multiprocessing_context='spawn')
    warmup = 5
    times = []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        e = time.perf_counter()
        if i >= warmup:
            times.append(e - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break
    print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')


if __name__ == '__main__':
    num_worker_options = [0, 2, 4, 6, 8]
    for n in num_worker_options:
        get_performance_metrics(n)
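For what it's worth, tensors returned from DataLoader worker processes are moved into shared memory rather than having their contents pickled, so one variant worth trying (a sketch only; I have not verified that it removes the plateau) is to return a torch tensor from __next__ instead of a raw numpy array:

# In MyIterable.__next__, hand back a tensor instead of a numpy array;
# DataLoader workers place tensors in shared memory, so only a handle
# crosses the process boundary instead of the full pickled payload.
def __next__(self):
    if self.i < self.n:
        self.i += 1
        sleep(0.003125)  # simulates data fetch time
        return torch.from_numpy(np.random.random((1000, 150)))
    else:
        raise StopIteration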