pytorch中Dataloader、sampler、generator的关系

lia*_*999 6 python pytorch

假设我有一个数据集:

datasets = [0,1,2,3,4]
Run Code Online (Sandbox Code Playgroud)

在场景一中,代码为:

torch.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets)
for data in ran_sampler:
  print(data)
Run Code Online (Sandbox Code Playgroud)

结果是1,3,4,0,2

在场景二中,代码为:

torch.manual_seed(1)

seed=1234
G = torch.Generator()
G.manual_seed(seed)

ran_sampler = RandomSampler(data_source=datasets)
dataloader = DataLoader(dataset=datasets, 
                        sampler=ran_sampler,
                        generator=G)
for data in ran_sampler:
  print(data)
Run Code Online (Sandbox Code Playgroud)

结果是1,3,4,0,2。事实上,给变量赋予任何值seed,结果仍然是1,3,4,0,2

在场景三中,代码为:

torch.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets)
dataloader = DataLoader(dataset=datasets, 
                        sampler=ran_sampler)
for data in dataloader:
  print(data)
Run Code Online (Sandbox Code Playgroud)

结果是4,1,3,0,2

我检查了源代码RandomSampler发现:

seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
Run Code Online (Sandbox Code Playgroud)

It shows that RandomSampler would create a generator itself if no generator is given. Therefore in theory, my scenario I, II and III would output the same results but scenario III outputs a different results. Why would this happen? I am lost in the source code of Dataloader and I am confused about the relationship between Dataloader, sampler and generator.

I have already asked a question about The shuffling order of DataLoader in pytorch. I understand that Dataloader would pass the generator to sampler in certain environments but in my scenario III, the RandomSampler has a generator already.

kmk*_*urn 5

场景 3 的调查实际上比我想象的要困难。让我们一一看看所有场景。在这个答案中,“生成器”的意思是“随机数生成器”,它是 的实例torch.Generator,而不是 Python 的生成器

场景1

这个场景很简单。当一个人RandomSampler在没有generator提供的情况下迭代创建时,采样器会创建自己的生成器,正如您所指出的。这个创建可以在定义中看到RandomSampler.__iter__

seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
Run Code Online (Sandbox Code Playgroud)

结果是1,3,4,0,2

场景2

传递给的生成器DataLoader仅用于(a)创建一个RandomSampler如果sampler未给出的情况以及(b)在使用多处理时为工作人员生成基本种子。文档字符串中描述了这两种用途。在您的代码中,sampler设置为ran_sampler并且不使用多处理(默认值)。因此,传递的生成器G没有任何作用。其目的可能是G确定 的随机抽样datasets。在这种情况下,G应传递给RandomSampler如下

seed = 1
G = torch.Generator()
G.manual_seed(seed)

ran_sampler = RandomSampler(data_source=datasets, generator=G)  # G is passed here
dataloader = DataLoader(dataset=datasets, 
                        sampler=ran_sampler)  # G is not passed here
for data in ran_sampler:
  print(data)
Run Code Online (Sandbox Code Playgroud)

但是,此代码的打印结果0,4,2,3,1与场景 1 不同。这是因为随机采样器在场景 1 中使用的实际种子不是 1(回想一下,RandomSampler在场景 1 中创建了自己的生成器)。为了使输出相同,我们需要使用相同的种子:

torch.manual_seed(1)

seed = int(torch.empty((), dtype=torch.int64).random_().item())  # use the same seed as Scenario 1
G = torch.Generator()
G.manual_seed(seed)

ran_sampler = RandomSampler(data_source=datasets, generator=G)
dataloader = DataLoader(dataset=datasets, 
                        sampler=ran_sampler)
for data in ran_sampler:
  print(data)
Run Code Online (Sandbox Code Playgroud)

此代码现在输出1,3,4,0,2.

场景3

在场景 3 中,人们会期望结果与场景 1 相同,因为它们使用相同的随机采样器,并且两个采样器都应该为自己的生成器生成相同的种子。然而,种子实际上是不同的,因为当DataLoader.__iter__被调用时,下面的代码(_BaseDataLoaderItertorch.utils.data.dataloader的类内部定义)也将运行:

self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
Run Code Online (Sandbox Code Playgroud)

这里,self._index_sampler是一个BatchSampler迭代ran_sampler ifself._sampler_iter实例。换句话说,为了ran_sampler创建自己的生成器,self._sampler_iter必须进行迭代。看完代码后BatchSampler.__iter__,您可能想知道为什么。原因是因为self._index_sampler.__iter__一个 Python 生成器,仅当返回的生成器迭代器(即 )被迭代时才会执行self._sampler_iter,这在上面的代码中不会发生。

请注意,Python 生成器不是一个随机数生成器torch.Generator,而是一个包含yield. 不幸的是,两者使用相同的术语,这可能会引起混乱。

另请注意,上面的代码 ( self._base_seed) 中生成了一个种子,该种子发生在迭代之前。 self._sampler_iter当场景 3 中的循环运行(即self._sampler_iter迭代)时,ran_sampler会创建自己的生成器。回想一下场景 1,此创建执行

seed = int(torch.empty((), dtype=torch.int64).random_().item())
Run Code Online (Sandbox Code Playgroud)

其中该调用Tensor.random_()是该方法的第二次调用,seed与场景1不同;现在,场景 1 的种子是通过第一次调用self._base_seed获得的。换句话说,等于场景 1 种子,而不是。场景 1 中的迭代两次提供了以下指示:Tensor.random_()self._base_seedseedran_sampler

torch.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets)
list(ran_sampler)  # first iteration
for data in ran_sampler:
  print(data)
Run Code Online (Sandbox Code Playgroud)

上面的代码输出4,1,3,0,2.