小编AIC*_*oda的帖子

迭代 torch.utils.data.random_split 的子集

我目前正在加载一个包含人工智能训练数据的文件夹。子文件夹代表标签名称以及内部相应的图像。使用 pyTorch 的 ImageFolder 加载器可以很好地实现这一点。

def load_dataset():
    data_path = 'C:/example_folder/'

    train_dataset_manual = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )

    train_loader_manual = torch.utils.data.DataLoader(
        train_dataset_manual,
        batch_size=1,
        num_workers=0,
        shuffle=True
    )

    return train_loader_manual

full_dataset = load_dataset()
Run Code Online (Sandbox Code Playgroud)

现在我想将此数据集分为训练数据集和测试数据集。我为此使用 random_split 函数:

training_data_size = 0.8

train_size = int(training_data_size * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
Run Code Online (Sandbox Code Playgroud)

full_dataset 是类型的对象torch.utils.data.dataloader.DataLoader。我可以用这样的循环来迭代它:

for batch_idx, (data, target) in enumerate(full_dataset):
    print(batch_idx)
Run Code Online (Sandbox Code Playgroud)

train_dataset类型的对象torch.utils.data.dataset.Subset。如果我尝试循环它,我会得到:

TypeError“DataLoader”对象不可下标:

for batch_idx, (data, target) in enumerate(train_dataset):
    print(batch_idx)
Run Code Online (Sandbox Code Playgroud)

我怎样才能循环它?我对 Python 比较陌生。

谢谢!

python loops pytorch

5
推荐指数
1
解决办法
5100
查看次数

标签 统计

loops ×1

python ×1

pytorch ×1