在pytorch中加载csv和图像数据集

rts*_*rts 3 image machine-learning computer-vision deep-learning pytorch

我正在使用 PyTorch 进行图像分类。我有一个单独的 Images 文件夹,并使用图像 ids 和 labels 训练和测试 csv 文件。我不知道如何组合这些图像和 ID 并转换为张量。

\n
    \n
  1. train.csv :包含图像的所有 ID,如 4325.jpg、2345.jpg、\xe2\x80\xa6so 以及包含猫、狗等标签。
  2. \n
  3. Image_data :包含具有 ID 名称的所有图像。
  4. \n
\n

Mit*_*iku 8

您可以通过继承 pytorch 的torch.utils.data.Dataset来创建自定义数据集类。

以下自定义数据集类的假设是

  • csv 文件格式为

文件名 标签
4325.jpg
2345.jpg
  • 所有图像都在里面images folder
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path, images_folder, transform = None):
        self.df = pd.read_csv(csv_path)
        self.images_folder = images_folder
        self.transform = transform
        self.class2index = {"cat":0, "dog":1}

    def __len__(self):
        return len(self.df)
    def __getitem__(self, index):
        filename = self.df[index, "FILENAME"]
        label = self.class2index[self.df[index, "LABEL"]]
        image = PIL.Image.open(os.path.join(self.images_folder, filename))
        if self.transform is not None:
            image = self.transform(image)
        return image, label
        
Run Code Online (Sandbox Code Playgroud)

现在,您可以使用此类通过 csv 文件和图像文件夹加载训练和测试数据集。


train_dataset = CustomDataset("path - to - train.csv", "path - to - images - folder"  )
test_dataset = CustomDataset("path - to - test.csv", "path - to - images - folder"  )


image, label = train_dataset[0]
Run Code Online (Sandbox Code Playgroud)

  • 很好的答案 - 对我来说唯一的问题是 `df[index, "SDFSDF"]` 引发了错误 - 相反,我使用了 `df.loc[index]["SDFSDF"]` 并且工作正常。 (3认同)