uke*_*emi 3 python parallel-processing numpy pytorch dataloader
PyTorch/Numpy 中存在一个错误,即当与 a DataLoader(即设置num_workers > 1)并行加载批次时,每个工作线程使用相同的 NumPy 随机种子,导致并行批次之间应用的任何随机函数都是相同的。这可以通过将种子生成器传递给worker_init_fn参数来解决,如下所示。
然而,这个问题在多个时期仍然存在。
最小的例子:
import numpy as np
from torch.utils.data import Dataset, DataLoader
class RandomDataset(Dataset):
def __getitem__(self, index):
return np.random.randint(0, 1000, 2)
def __len__(self):
return 4
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1,
num_workers=2,
worker_init_fn = lambda x: np.random.seed(x))
for epoch in range(3):
print(f'\nEpoch {epoch}')
for batch in dataloader:
print(batch)
Run Code Online (Sandbox Code Playgroud)
正如您所看到的,虽然一个时期内的并行批次现在会产生不同的结果,但跨时期的结果是相同的:
Epoch 0
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908, 72]])
Epoch 1
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908, 72]])
Epoch 2
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908, 72]])
Run Code Online (Sandbox Code Playgroud)
如何解决这种行为?
使用空参数例如worker_init_fn = lambda _: np.random.seed()似乎可以解决此问题 - 此解决方法有任何问题吗?
我能想到的最好的方法是使用 pytorch 为 numpy 和 random 设置的种子:
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
def worker_init_fn(worker_id):
torch_seed = torch.initial_seed()
random.seed(torch_seed + worker_id)
if torch_seed >= 2**30: # make sure torch_seed + workder_id < 2**32
torch_seed = torch_seed % 2**30
np.random.seed(torch_seed + worker_id)
class RandomDataset(Dataset):
def __getitem__(self, index):
return np.random.randint(0, 1000, 2)
def __len__(self):
return 4
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1,
num_workers=2,
worker_init_fn = worker_init_fn)
for epoch in range(3):
print(f'\nEpoch {epoch}')
for batch in dataloader:
print(batch)
Run Code Online (Sandbox Code Playgroud)
输出:
Epoch 0
tensor([[593, 191]])
tensor([[207, 469]])
tensor([[976, 714]])
tensor([[ 13, 119]])
Epoch 1
tensor([[836, 664]])
tensor([[138, 836]])
tensor([[409, 313]])
tensor([[ 2, 221]])
Epoch 2
tensor([[269, 888]])
tensor([[315, 619]])
tensor([[892, 774]])
tensor([[ 70, 771]])
Run Code Online (Sandbox Code Playgroud)
或者,您可以使用int(time.time())播种numpy和random,假设每个纪元的运行时间超过 1 秒。