我目前正在加载一个包含人工智能训练数据的文件夹。子文件夹代表标签名称以及内部相应的图像。使用 pyTorch 的 ImageFolder 加载器可以很好地实现这一点。
def load_dataset():
data_path = 'C:/example_folder/'
train_dataset_manual = torchvision.datasets.ImageFolder(
root=data_path,
transform=torchvision.transforms.ToTensor()
)
train_loader_manual = torch.utils.data.DataLoader(
train_dataset_manual,
batch_size=1,
num_workers=0,
shuffle=True
)
return train_loader_manual
full_dataset = load_dataset()
Run Code Online (Sandbox Code Playgroud)
现在我想将此数据集分为训练数据集和测试数据集。我为此使用 random_split 函数:
training_data_size = 0.8
train_size = int(training_data_size * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
Run Code Online (Sandbox Code Playgroud)
full_dataset 是类型的对象torch.utils.data.dataloader.DataLoader。我可以用这样的循环来迭代它:
for batch_idx, (data, target) in enumerate(full_dataset):
print(batch_idx)
Run Code Online (Sandbox Code Playgroud)
是train_dataset类型的对象torch.utils.data.dataset.Subset。如果我尝试循环它,我会得到:
TypeError“DataLoader”对象不可下标:
for batch_idx, (data, target) in enumerate(train_dataset):
print(batch_idx)
Run Code Online (Sandbox Code Playgroud)
我怎样才能循环它?我对 Python 比较陌生。
谢谢!