尝试在 Pytorch 中加载自定义数据集

dsi*_*000 4 python machine-learning dataset computer-vision pytorch

我刚开始使用 PyTorch,不幸的是,在使用我自己的训练/测试图像数据集进行自定义算法时有点困惑。首先,我正在制作一个小型的“hello world”式卷积衬衫/袜子/裤子分类网络。我只加载了一些图像,只是确保 PyTorch 可以加载它们并将它们正确地转换为 32x32 可用图像。我的 ImageFolder 设置如下:

imgs/socks/ sockimages .jpeg
imgs/pants/ pantsimages .jpeg
imgs/shirt/ shirtimages .jpeg

以及我的测试图像文件夹的类似设置。根据我目前的知识,PyTorch 内置的图像加载器应该从训练/测试图像中的子文件夹名称中读取标签。但是,我TypeError抱怨我的迭代器不可迭代。这是我的代码和错误:

import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms

transform = transforms.Compose(
[transforms.ToTensor(),
 transforms.Scale((32,32)),
 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = dset.ImageFolder(root="imgs",transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True,         num_workers=2)

testset = dset.ImageFolder(root='tests',transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=True,     num_workers=2)

classes=('shirt','pants','sock')

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
Run Code Online (Sandbox Code Playgroud)

错误:

TypeError: 'builtin_function_or_method' object is not iterable
Run Code Online (Sandbox Code Playgroud)

它说它引用了包含的行dataiter.next(),这意味着编译器认为我不能迭代dataiter

请帮忙!提前致谢,

-David Sillman,PyTorch 新手

Man*_*nas 5

我认为错误是因为在transform.Compose你首先做的时候.ToTensor(),你应该做.Scale()Pytorch张量PIL Images上有转换,不能互换。阅读它说的文档

class torchvision.transforms.Scale(size, interpolation=2) [...] 将输入 PIL.Image 重新缩放到给定的大小。

当您在缩放之前将该图像更改为 Pytorch 张量从而使其崩溃时。

应该改为:

transform = transforms.Compose(
                   [transforms.Scale((32,32)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
Run Code Online (Sandbox Code Playgroud)

PIL Image在张量上应用转换时,您将收到此错误。