Pytorch 中正确的数据加载、分割和扩充

Bra*_*ion 3 neural-network pytorch

该教程似乎没有解释我们应该如何加载、拆分和进行适当的增强。

让我们有一个由汽车和猫组成的数据集。文件夹结构为:

data
  cat
    0101.jpg
    0201.jpg
    ...
  dogs
    0101.jpg
    0201.jpg
    ...
Run Code Online (Sandbox Code Playgroud)

首先,我通过 datasets.ImageFolder 函数加载数据集。图像函数有命令“TRANSFORM”,我们可以在其中设置一些增强命令,但我们不想将增强应用于测试数据集!所以让我们继续使用transform=None。

data = datasets.ImageFolder(root='data')
Run Code Online (Sandbox Code Playgroud)

显然,我们没有文件夹结构训练和测试,因此我认为一个好的方法是使用split_dataset 函数

    train_size = int(split * len(data))
    test_size = len(data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])
Run Code Online (Sandbox Code Playgroud)

现在让我们按以下方式加载数据。

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=8,
                                              shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=8,
                                              shuffle=True)
Run Code Online (Sandbox Code Playgroud)

如何将转换(数据增强)应用于“train_loader”图像?

基本上我需要: 1. 从上面解释的文件夹结构加载数据 2. 将数据拆分为测试/训练部分 3. 在训练部分应用增强。

Ber*_*iel 5

我不确定是否有推荐的方法来执行此操作,但这就是我解决此问题的方法:

鉴于torch.utils.data.random_split()返回Subset,我们不能(我们可以吗?这里不是100%确定我仔细检查了,我们不能)利用它们的内部数据集,因为它们是相同的(唯一的区别在于索引)。在这种情况下,我将实现一个简单的类来应用转换,如下所示:

from torch.utils.data import Dataset

class ApplyTransform(Dataset):
    """
    Apply transformations to a Dataset

    Arguments:
        dataset (Dataset): A Dataset that returns (sample, target)
        transform (callable, optional): A function/transform to be applied on the sample
        target_transform (callable, optional): A function/transform to be applied on the target

    """
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform
        # yes, you don't need these 2 lines below :(
        if transform is None and target_transform is None:
            print("Am I a joke to you? :)")

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.dataset)
Run Code Online (Sandbox Code Playgroud)

然后在将数据集传递到数据加载器之前使用它:

import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.ToTensor(),
    # ...
])
train_dataset = ApplyTransform(train_dataset, transform=train_transform)

# continue with DataLoaders...
Run Code Online (Sandbox Code Playgroud)