来自 csv 文件路径和标签的 Pytorch 数据加载器

Cop*_*OfA 2 python pytorch

我有一个用于训练和测试数据集的 csv 文件,其中包含文件位置和标签。该数据框的头部是:

df.head()
Out[46]: 
             file_path  label
0  \\images\\29771.png      0
1  \\images\\55201.png      0
2  \\images\\00715.png      1
3  \\images\\33214.png      0
4  \\images\\99841.png      1
Run Code Online (Sandbox Code Playgroud)

我的文件路径有多个位置,但空间有限,因此无法将它们复制到 \0 和 \1 文件夹位置。如何使用此数据框创建 pytorch 数据加载器和/或数据集对象?

Kar*_*arl 5

只需为您的数据集编写一个自定义__getitem__方法即可。

class MyData(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        image = load_image(self.df.file_path[index])
        label = self.df.label[index]

        return image, label
Run Code Online (Sandbox Code Playgroud)

哪里load_image有一个函数可以将文件名读取为您需要的任何格式。