img 应该是 PIL Image。得到 <class 'torch.Tensor'>

Ham*_*man 8 python pytorch

我正在尝试遍历加载程序以检查它是否正常工作,但是给出了以下错误:

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

我试着将两者transforms.ToTensor()transforms.ToPILImage()和它给我一个错误,要求相反。即,使用ToPILImage(),它将要求张量,反之亦然。

# Imports here
%matplotlib inline
import matplotlib.pyplot as plt
from torch import nn, optim
import torch.nn.functional as F
import torch
from torchvision import transforms, datasets, models
import seaborn as sns
import pandas as pd
import numpy as np

data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'

#Creating transform for training set
train_transforms = transforms.Compose(
[transforms.Resize(255), 
transforms.CenterCrop(224), 
transforms.ToTensor(), 
transforms.RandomHorizontalFlip(), 
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

#Creating transform for test set
test_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224), 
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])

#transforming for all data
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(test_dir, transform = test_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform = test_transforms)

#Creating data loaders for test and training sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size = 32, 
shuffle = True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
images, labels = next(iter(trainloader))
Run Code Online (Sandbox Code Playgroud)

plt.imshow(images[0])如果运行正常,它应该允许我在运行后简单地查看图像。

Anu*_*ngh 18

transforms.RandomHorizontalFlip()工作PIL.Images,不是torch.Tensor。在上面的代码中,您在transforms.ToTensor()之前应用transforms.RandomHorizontalFlip(),这会导致张量。

但是,根据此处的官方 pytorch 文档,

transforms.RandomHorizo​​ntalFlip() 以给定的概率随机水平翻转给定的 PIL 图像。

因此,只需更改上面代码中转换的顺序,如下所示:

train_transforms = transforms.Compose([transforms.Resize(255), 
                                       transforms.CenterCrop(224),  
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(), 
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 
Run Code Online (Sandbox Code Playgroud)

  • 尽管您当时可能是对的,但对于任何未来的读者来说,情况已经改变,现在也可以提供张量: &gt; 以给定的概率随机水平翻转给定的图像。图像可以是 PIL 图像或火炬张量” (3认同)

Bip*_*Das 6

只需添加transforms.ToPILImage()转换为 pil 图像,然后它就可以工作,例如:

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
Run Code Online (Sandbox Code Playgroud)