使用 DataLoaders 在 PyTorch 中验证数据集

qal*_*lis 5 neural-network pytorch

我想在 PyTorch 和 Torchvision 中加载 MNIST 数据集,将其分为训练、验证和测试部分。到目前为止,我有:

def load_dataset():
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            '/data/', train=True, download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor()])),
        batch_size=batch_size_train, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            '/data/', train=False, download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor()])),
        batch_size=batch_size_test, shuffle=True)
Run Code Online (Sandbox Code Playgroud)

如果训练数据集在DataLoader. 我想使用训练数据集中的最后 10000 个示例作为验证数据集(我知道我应该做 CV 以获得更准确的结果,我只想在这里快速验证)。

qal*_*lis 13

在 PyTorch 中将训练数据集拆分为训练和验证实际上比应有的要困难得多。

首先,将训练集拆分为训练和验证子集(class Subset),它们不是数据集 (class Dataset):

train_subset, val_subset = torch.utils.data.random_split(
        train, [50000, 10000], generator=torch.Generator().manual_seed(1))
Run Code Online (Sandbox Code Playgroud)

然后从这些数据集中获取实际数据:

X_train = train_subset.dataset.data[train_subset.indices]
y_train = train_subset.dataset.targets[train_subset.indices]

X_val = val_subset.dataset.data[val_subset.indices]
y_val = val_subset.dataset.targets[val_subset.indices]
Run Code Online (Sandbox Code Playgroud)

请注意,这样我们就没有对象Dataset,因此我们不能使用DataLoader对象进行批量训练。如果您想使用 DataLoaders,它们可以直接与子集一起使用:

train_loader = DataLoader(dataset=train_subset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(dataset=val_subset, shuffle=False, batch_size=BATCH_SIZE)
Run Code Online (Sandbox Code Playgroud)

  • `val_loader` 应该采用 `val_subset` 作为数据集参数,而不是 `train_subset`,对吗? (2认同)