Posts by lia*_*999

Relationship between DataLoader, sampler, and generator in PyTorch

Suppose I have a dataset:

datasets = [0,1,2,3,4]

In scenario one, the code is:

import torch
from torch.utils.data import DataLoader, RandomSampler

torch.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets)
for data in ran_sampler:
    print(data)

The result is 1, 3, 4, 0, 2.

In scenario two, the code is:

torch.manual_seed(1)

seed=1234
G = torch.Generator()
G.manual_seed(seed)

ran_sampler = RandomSampler(data_source=datasets)
dataloader = DataLoader(dataset=datasets, 
                        sampler=ran_sampler,
                        generator=G)
# Note: this loop iterates over the sampler itself, not over the dataloader
for data in ran_sampler:
    print(data)

The result is 1, 3, 4, 0, 2. In fact, whatever value I assign to the variable seed, the result is still 1, 3, 4, 0, 2.
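For comparison, here is a minimal sketch of what happens when the generator is handed to the sampler itself through the `generator` argument of `RandomSampler`, rather than only to the `DataLoader`. The helper name `sampler_order` is made up for illustration and is not part of the code above. In this setup the order should follow the generator's seed and not the global seed set with `torch.manual_seed`.

```python
import torch
from torch.utils.data import RandomSampler

datasets = [0, 1, 2, 3, 4]


def sampler_order(generator_seed, global_seed):
    """Hypothetical helper: return the index order for a given pair of seeds."""
    # Seed the global RNG; it should not matter once the sampler has its own generator.
    torch.manual_seed(global_seed)

    g = torch.Generator()
    g.manual_seed(generator_seed)

    # The generator is passed to the sampler, not only to a DataLoader.
    sampler = RandomSampler(data_source=datasets, generator=g)
    return list(sampler)


order_a = sampler_order(generator_seed=1234, global_seed=1)
order_b = sampler_order(generator_seed=1234, global_seed=2)
order_c = sampler_order(generator_seed=4321, global_seed=1)

print(order_a)  # same as order_b: only the generator seed matters here
print(order_b)
print(order_c)  # drawn from a differently seeded generator
```

In scenario two the generator only reaches the `DataLoader`, and the loop iterates over `ran_sampler` directly, so the sampler never sees `G`.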

In scenario three, the code is:

torch.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets)
dataloader = DataLoader(dataset=datasets, 
                        sampler=ran_sampler)
for data in dataloader:
    print(data)

The result is 4, 1, 3, 0, 2.

I checked the source code of RandomSampler and found:

seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)

It shows that RandomSampler would create …
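The snippet from the `RandomSampler` source suggests that, when no generator is supplied, the sampler draws its seed from the global RNG each time it is iterated. A minimal sketch of that dependence follows; the variable names `first_pass`, `second_pass`, and `replay` are only illustrative. Resetting the global seed should replay the same order, while a second pass without reseeding draws a fresh seed from the already advanced global state.

```python
import torch
from torch.utils.data import RandomSampler

datasets = [0, 1, 2, 3, 4]

# No generator is given, so the sampler falls back to the global RNG for its seed.
sampler = RandomSampler(data_source=datasets)

torch.manual_seed(1)
first_pass = list(sampler)   # seed drawn from the freshly seeded global RNG
second_pass = list(sampler)  # seed drawn after the global RNG has advanced

torch.manual_seed(1)
replay = list(sampler)       # global RNG reset, so the first order is replayed

print(first_pass)
print(second_pass)
print(replay)
```

Anything that consumes random numbers from the global RNG between `torch.manual_seed(1)` and the start of iteration would therefore be expected to change the order the sampler produces.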

python pytorch
