相关疑难解决方法(0)

PyTorch DataLoader 对并行运行的批次使用相同的随机种子

PyTorch/Numpy 中存在一个错误,即当与 a DataLoader(即设置num_workers > 1)并行加载批次时,每个工作线程使用相同的 NumPy 随机种子,导致并行批次之间应用的任何随机函数都是相同的。

最小的例子:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 9
    
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, num_workers=3)

for batch in dataloader:
    print(batch)
Run Code Online (Sandbox Code Playgroud)

如您所见,对于每个并行批次集 (3),结果是相同的:

# First 3 batches
tensor([[891, 674]])
tensor([[891, 674]])
tensor([[891, 674]])
# Second 3 batches
tensor([[545, 977]])
tensor([[545, 977]])
tensor([[545, 977]])
# Third 3 batches
tensor([[880, 688]])
tensor([[880, 688]])
tensor([[880, …
Run Code Online (Sandbox Code Playgroud)

python parallel-processing numpy pytorch dataloader

6
推荐指数
1
解决办法
9795
查看次数

标签 统计

dataloader ×1

numpy ×1

parallel-processing ×1

python ×1

pytorch ×1