“DataLoader”对象不支持索引

Far*_*han 3 python computer-vision imagenet pytorch

我已经通过这个 pytorch api 通过设置 download=True 下载了 ImageNet 数据集。但我无法遍历数据加载器。

错误说“'DataLoader' 对象不支持索引”

trainset = torch.utils.data.DataLoader(
    datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
                      download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)
Run Code Online (Sandbox Code Playgroud)

我尝试了一种简单的方法,我只是尝试运行以下命令,

trainloader[0]
Run Code Online (Sandbox Code Playgroud)

在根目录中,模式是

root/  
    train/  
          n01440764/
          n01443537/ 
                   n01443537_2.jpg
Run Code Online (Sandbox Code Playgroud)

官网上的文档没有说别的。https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet

我究竟做错了什么 ?

Szy*_*zke 5

嗯,答案很简单(除了另一个答案中提到的错误)。

DataLoader没有__getitem__方法在源代码中自己查看)。

它用于对数据(或成批数据)进行迭代,而不是随机访问。如果你想访问你应该使用的特定元素torch.utils.data.Dataset,在你的情况下:

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', )
trainset[0]
Run Code Online (Sandbox Code Playgroud)

获得一批

如果你想得到一个批次,你可以迭代它并在之后中断:

for batch in dataloader:
    print(batch) # or anything else you want to do
    break
Run Code Online (Sandbox Code Playgroud)

DataLoader以默认或指定的方式创建随机索引(参见samplers),因此没有,__getitem__因为它对这个对象没有意义。

您也可以继承DataLoader并创建您自己的__getitem__函数,做您想做的事(虽然更复杂)。

完整示例

# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', download=True)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)

for batch in trainloader:
    print(batch)
    break
Run Code Online (Sandbox Code Playgroud)

上面应该打印第一批里面的东西。


Far*_*han 5

解决方案

input_transform = standard_transforms.Compose([
    transforms.Resize((255,255)), # to Make sure all the 
    transforms.CenterCrop(224),   # imgs are at the same size 
    transforms.ToTensor()
])  


# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
                             split='train', download=False, transform = input_transform)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=False)


for batch_idx, data in enumerate(trainloader, 0):
    x, y = data 
    break
Run Code Online (Sandbox Code Playgroud)