PyTorch DataLoader 在每个时期使用相同的随机转换

uke*_*emi 3 python parallel-processing numpy pytorch dataloader

PyTorch/Numpy 中存在一个错误,即当与 a DataLoader(即设置num_workers > 1)并行加载批次时,每个工作线程使用相同的 NumPy 随机种子,导致并行批次之间应用的任何随机函数都是相同的。这可以通过将种子生成器传递给worker_init_fn参数来解决,如下所示

然而,这个问题在多个时期仍然存在。

最小的例子:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 4

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, 
                        num_workers=2, 
                        worker_init_fn = lambda x: np.random.seed(x))

for epoch in range(3):
    print(f'\nEpoch {epoch}')
    for batch in dataloader:
        print(batch)
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,虽然一个时期内的并行批次现在会产生不同的结果,但跨时期的结果是相同的:

Epoch 0
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908,  72]])

Epoch 1
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908,  72]])

Epoch 2
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908,  72]])
Run Code Online (Sandbox Code Playgroud)

如何解决这种行为?


使用空参数例如worker_init_fn = lambda _: np.random.seed()似乎可以解决此问题 - 此解决方法有任何问题吗?

Tu *_*Bui 5

我能想到的最好的方法是使用 pytorch 为 numpy 和 random 设置的种子:

import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

def worker_init_fn(worker_id):
    torch_seed = torch.initial_seed()
    random.seed(torch_seed + worker_id)
    if torch_seed >= 2**30:  # make sure torch_seed + workder_id < 2**32
        torch_seed = torch_seed % 2**30
    np.random.seed(torch_seed + worker_id)

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 4

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, 
                        num_workers=2, 
                        worker_init_fn = worker_init_fn)

for epoch in range(3):
    print(f'\nEpoch {epoch}')
    for batch in dataloader:
        print(batch)
Run Code Online (Sandbox Code Playgroud)

输出:

Epoch 0
tensor([[593, 191]])
tensor([[207, 469]])
tensor([[976, 714]])
tensor([[ 13, 119]])

Epoch 1
tensor([[836, 664]])
tensor([[138, 836]])
tensor([[409, 313]])
tensor([[  2, 221]])

Epoch 2
tensor([[269, 888]])
tensor([[315, 619]])
tensor([[892, 774]])
tensor([[ 70, 771]])
Run Code Online (Sandbox Code Playgroud)

或者,您可以使用int(time.time())播种numpyrandom,假设每个纪元的运行时间超过 1 秒。