PyTorch DataLoader 如何与 PyTorch 数据集交互以转换批次?

roc*_*ves 9 python pytorch pytorch-dataloader

我正在为 NLP 相关任务创建自定义数据集。

在 PyTorch自定义 datast 教程中,我们看到该__getitem__()方法在返回样本之前为转换留出了空间:

def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
       
        ### SOME DATA MANIPULATION HERE ###

        sample = {'image': image, 'landmarks': landmarks}
        if self.transform:
            sample = self.transform(sample)

        return sample
Run Code Online (Sandbox Code Playgroud)

然而,这里的代码:

        if torch.is_tensor(idx):
            idx = idx.tolist()
Run Code Online (Sandbox Code Playgroud)

意味着应该能够一次检索多个项目,这让我想知道:

  1. 这种转换如何作用于多个项目?以教程中的自定义转换为例。它们看起来无法在一次调用中应用于一批样本。

  2. 相关的是,如果转换只能应用于单个样本,DataLoader 如何并行检索一批多个样本并应用所述转换?

Joh*_*tud 6

  1. 这种转换如何作用于多个项目?他们通过使用数据加载器处理多个项目。通过使用转换,您可以指定单次数据发射应该发生什么(例如,batch_size=1)。数据加载器接受您的指定batch_size并调用n火炬__getitem__数据集中的方法,将转换应用于发送到训练/验证的每个样本。然后,它将n样本整理到数据加载器发出的批量大小中。

  2. 相关的是,如果转换只能应用于单个样本,DataLoader 如何并行检索一批多个样本并应用所述转换?希望以上内容对您有意义。并行化是由 torch 数据集类和数据加载器完成的,您可以在其中指定num_workers. Torch 将腌制数据集并将其传播给工作人员。