向 pytorch 数据加载器/数据集添加自定义标签不适用于自定义数据集

Leo*_*ewk 4 python machine-learning computer-vision pytorch

我正在 Kaggle 上进行仙人掌图像竞赛,我正在尝试将 PyTorch 数据加载器用于我的 CNN。但是,我遇到了无法为训练集设置标签的问题。训练集图像在文件夹中给出,标签在 csv 文件中。这是我的代码。

 train = torchvision.datasets.ImageFolder(root='../input/train', 
 transform=transform)

 train.targets = torch.from_numpy(df['has_cactus'].values)

 train_loader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True, num_workers=2)

 for i, data in enumerate(train_loader, 0):
     print(data[1])
Run Code Online (Sandbox Code Playgroud)

此代码输出全为零的批处理张量,这显然是不正确的,因为绝大多数标签(如果您要查看数据帧)都是 1。我相信这是将标签分配给“train.targets”的问题。如果在分配其他标签之前打印“train.targets”,它会返回一个全为零的张量,这与我得到的错误结果一致。我该如何解决这个问题?

小智 7

我通常继承内置的 DataSet 类,如下所示:

from torch.utils.data import DataLoader
class DataSet:

    def __init__(self, root):
        """Init function should not do any heavy lifting, but
            must initialize how many items are available in this data set.
        """

        self.ROOT = root
        self.images = read_images(root + "/images")
        self.labels = read_labels(root + "/labels")

    def __len__(self):
        """return number of points in our dataset"""

        return len(self.images)

    def __getitem__(self, idx):
        """ Here we have to return the item requested by `idx`
            The PyTorch DataLoader class will use this method to make an iterable for
            our training or validation loop.
        """

        img = images[idx]
        label = labels[idx]

        return img, label
Run Code Online (Sandbox Code Playgroud)

现在,您可以创建此类的一个实例,

ds = Dataset('../input/train')
Run Code Online (Sandbox Code Playgroud)

现在,您可以实例化 DataLoader:

dl = DataLoader(ds, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True)
Run Code Online (Sandbox Code Playgroud)

这将创建您可以访问的批量数据:

for image, label in dl:
    print(label)
Run Code Online (Sandbox Code Playgroud)

  • 谢谢您的答复。您如何建议我在代码中实现“read_images”方法? (2认同)