如何将Pytorch数据加载器转换为numpy数组以使用matplotlib显示图像数据?

Aer*_*yes 3 python numpy machine-learning conv-neural-network pytorch

我是Pytorch的新手。在开始使用CNN进行训练之前,我一直在尝试学习如何查看输入的图像。我很难将图像更改为可与matplotlib一起使用的形式。

到目前为止,我已经尝试过了:

from multiprocessing import freeze_support

import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam

import matplotlib.pyplot as plt
import numpy as np
import PIL

# Training configuration. batch_size and num_of_workers are passed to the
# DataLoader below; num_classes is not referenced in the visible part of the script.
num_classes = 5
batch_size = 100
num_of_workers = 5

# Raw strings keep every backslash literal. The previous literals mixed "\\"
# with unrecognised escapes such as "\A" and "\P", which only worked because
# Python passes invalid escapes through (and warns about them on newer versions).
DATA_PATH_TRAIN = r'C:\Users\Aeryes\PycharmProjects\simplecnn\images\train'
DATA_PATH_TEST = r'C:\Users\Aeryes\PycharmProjects\simplecnn\images\test'

# Preprocessing pipeline handed to ImageFolder as its `transform`.
trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
    transforms.CenterCrop(32),
    # Normalize operates on tensors, so the image must be converted first.
    # `transforms.ToPImage` does not exist in torchvision; ToTensor is the
    # conversion that belongs at this point in the pipeline.
    transforms.ToTensor(),
    # Per-channel mean 0.5 and std 0.5; imshow() reverses this with `img / 2 + 0.5`.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Image dataset rooted at the training directory, with `trans` applied to each image.
train_dataset = datasets.ImageFolder(DATA_PATH_TRAIN, transform=trans)

# Shuffled batches of `batch_size` samples, loaded by `num_of_workers` workers.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_of_workers,
)

def imshow(img):
    """Undo the normalization on ``img``, print it as a NumPy array, and plot it.

    ``img`` must provide ``.numpy()``; ``main()`` passes the ``images`` object
    it gets from the DataLoader iterator.
    """
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    print(npimg)
    # NOTE(review): the axes tuple lists axis 1 twice, so np.transpose raises
    # "ValueError: repeated axis in transpose" (see the traceback in this post).
    # plt.imshow needs a single image laid out as (height, width, channel).
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))

def main():
    """Fetch one batch from ``train_loader``, display it, and print its labels."""
    # get some random training images
    dataiter = iter(train_loader)
    # NOTE(review): relies on the iterator exposing a .next() method; the
    # builtin next(dataiter) is the portable spelling -- confirm against the
    # installed PyTorch version.
    images, labels = dataiter.next()

    # show images
    # NOTE(review): `images` is the whole batch here, not a single image.
    imshow(images)
    # print labels
    # NOTE(review): `classes` is not defined anywhere in the visible script, and
    # range(4) is hard-coded while batch_size is 100 -- confirm the intent.
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

# Run main() only when this file is executed as a script, not when it is imported.
# NOTE(review): with num_workers > 0 on Windows this guard is presumably
# required because worker processes re-import the module -- confirm.
if __name__ == "__main__":
    main()
Run Code Online (Sandbox Code Playgroud)

但是,这引发了错误:

  [[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353
    0.27450982]
   [0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646
    0.40784314]
   [0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354
    0.34509805]
   ...
   [0.2784314  0.21960783 0.2352941  ... 0.5803922  0.46666667
    0.25882354]
   [0.26666668 0.16862744 0.23137254 ... 0.2901961  0.29803923
    0.2509804 ]
   [0.30980393 0.39607844 0.28627452 ... 0.1490196  0.10588235
    0.19607842]]

  [[0.2352941  0.06274509 0.15686274 ... 0.09411764 0.3019608
    0.19215685]
   [0.22745097 0.07843137 0.12549019 ... 0.07843137 0.10588235
    0.3019608 ]
   [0.20392156 0.13333333 0.1607843  ... 0.16862744 0.2117647
    0.22745097]
   ...
   [0.18039215 0.16862744 0.1490196  ... 0.45882353 0.36078432
    0.16470587]
   [0.1607843  0.10588235 0.14117646 ... 0.2117647  0.18039215
    0.10980392]
   [0.18039215 0.3019608  0.2117647  ... 0.11372548 0.06274509
    0.04705882]]]


 ...


 [[[0.8980392  0.8784314  0.8509804  ... 0.627451   0.627451
    0.627451  ]
   [0.8509804  0.8235294  0.7921569  ... 0.54901963 0.5568628
    0.56078434]
   [0.7921569  0.7529412  0.7176471  ... 0.47058824 0.48235294
    0.49411765]
   ...
   [0.3764706  0.38431373 0.3764706  ... 0.4509804  0.43137255
    0.39607844]
   [0.38431373 0.39607844 0.3882353  ... 0.4509804  0.43137255
    0.39607844]
   [0.3882353  0.4        0.39607844 ... 0.44313726 0.42352942
    0.39215687]]

  [[0.9254902  0.90588236 0.88235295 ... 0.60784316 0.6
    0.5921569 ]
   [0.88235295 0.85490197 0.8235294  ... 0.5411765  0.5372549
    0.53333336]
   [0.8235294  0.7882353  0.75686276 ... 0.47058824 0.47058824
    0.47058824]
   ...
   [0.50980395 0.5176471  0.5137255  ... 0.58431375 0.5647059
    0.53333336]
   [0.5137255  0.53333336 0.5254902  ... 0.58431375 0.5686275
    0.53333336]
   [0.5176471  0.53333336 0.5294118  ... 0.5764706  0.56078434
    0.5294118 ]]

  [[0.95686275 0.9372549  0.90588236 ... 0.18823528 0.19999999
    0.20784312]
   [0.9098039  0.8784314  0.8352941  ... 0.1607843  0.17254901
    0.18039215]
   [0.84313726 0.7921569  0.7490196  ... 0.1372549  0.14509803
    0.15294117]
   ...
   [0.03921568 0.05490196 0.05098039 ... 0.11764705 0.09411764
    0.02745098]
   [0.04705882 0.07843137 0.06666666 ... 0.12156862 0.10196078
    0.03529412]
   [0.05098039 0.0745098  0.07843137 ... 0.12549019 0.10196078
    0.04705882]]]


 [[[0.30588236 0.28627452 0.24313724 ... 0.2901961  0.26666668
    0.21568626]
   [0.8156863  0.6666667  0.5921569  ... 0.18039215 0.23921567
    0.21568626]
   [0.9019608  0.83137256 0.85490197 ... 0.21960783 0.36862746
    0.23921567]
   ...
   [0.7058824  0.83137256 0.85490197 ... 0.2627451  0.24313724
    0.20784312]
   [0.7137255  0.84313726 0.84705883 ... 0.26666668 0.29803923
    0.21568626]
   [0.7254902  0.8235294  0.8392157  ... 0.2509804  0.27058825
    0.2352941 ]]

  [[0.24705881 0.22745097 0.19215685 ... 0.2784314  0.25490198
    0.19607842]
   [0.59607846 0.37254903 0.29803923 ... 0.16470587 0.22745097
    0.20392156]
   [0.5921569  0.4509804  0.49803922 ... 0.20784312 0.3764706
    0.2352941 ]
   ...
   [0.42352942 0.4627451  0.42352942 ... 0.23921567 0.23137254
    0.19999999]
   [0.45882353 0.5176471  0.35686275 ... 0.23921567 0.26666668
    0.19607842]
   [0.41568628 0.44313726 0.34901962 ... 0.21960783 0.23921567
    0.21568626]]

  [[0.23137254 0.20784312 0.1490196  ... 0.30588236 0.28627452
    0.19607842]
   [0.61960787 0.3764706  0.26666668 ... 0.16470587 0.24313724
    0.21568626]
   [0.57254905 0.43137255 0.48235294 ... 0.2235294  0.40392157
    0.25882354]
   ...
   [0.4        0.42352942 0.37254903 ... 0.25490198 0.24705881
    0.21568626]
   [0.43137255 0.4509804  0.29411766 ... 0.25882354 0.28235295
    0.20392156]
   [0.38431373 0.3529412  0.25490198 ... 0.2352941  0.25490198
    0.23137254]]]


 [[[0.06274509 0.09019607 0.11372548 ... 0.5803922  0.5176471
    0.59607846]
   [0.09411764 0.14509803 0.1372549  ... 0.5294118  0.49803922
    0.5058824 ]
   [0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942
    0.38431373]
   ...
   [0.15294117 0.12941176 0.1607843  ... 0.85882354 0.8509804
    0.80784315]
   [0.14509803 0.10588235 0.1607843  ... 0.8666667  0.85882354
    0.8       ]
   [0.1490196  0.10588235 0.16470587 ... 0.827451   0.8156863
    0.7921569 ]]

  [[0.06666666 0.12156862 0.17647058 ... 0.59607846 0.5529412
    0.6039216 ]
   [0.07058823 0.10588235 0.11764705 ... 0.56078434 0.5254902
    0.5372549 ]
   [0.03921568 0.0745098  0.09803921 ... 0.48235294 0.4392157
    0.4117647 ]
   ...
   [0.2117647  0.14509803 0.2784314  ... 0.43137255 0.3529412
    0.34117648]
   [0.2235294  0.11372548 0.2509804  ... 0.4509804  0.39607844
    0.2509804 ]
   [0.25490198 0.12156862 0.24705881 ... 0.38039216 0.36078432
    0.3254902 ]]

  [[0.05490196 0.09803921 0.12549019 ... 0.46666667 0.38039216
    0.45490196]
   [0.06274509 0.09803921 0.10196078 ... 0.44705883 0.41568628
    0.3882353 ]
   [0.03921568 0.06666666 0.0862745  ... 0.3764706  0.33333334
    0.28235295]
   ...
   [0.12156862 0.14509803 0.16862744 ... 0.15686274 0.0745098
    0.09411764]
   [0.10588235 0.11372548 0.16862744 ... 0.25882354 0.18431371
    0.05490196]
   [0.12156862 0.11372548 0.17254901 ... 0.2352941  0.17254901
    0.14117646]]]]
Traceback (most recent call last):
  File "image_loader.py", line 51, in <module>
    main()
  File "image_loader.py", line 46, in main
    imshow(images)
  File "image_loader.py", line 38, in imshow
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 598, in transpose
    return _wrapfunc(a, 'transpose', axes)
  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 51, in _wrapfunc
    return getattr(obj, method)(*args, **kwds)
ValueError: repeated axis in transpose
Run Code Online (Sandbox Code Playgroud)

我试图打印出数组以了解其维度,但是我不知道该如何解读这些输出。这很令人困惑。

这是我的直接问题:在使用DataLoader对象中的张量进行训练之前,如何查看输入图像?

hkc*_*rex 6

首先,dataloader输出4维张量- [batch, channel, height, width]。Matplotlib和其他图像处理库经常需要[height, width, channel]。您使用转置是正确的,只是使用方式不正确。

您的`images`中包含很多图像,因此首先您需要从中选择一个图像(或编写一个for循环以保存所有图像)。这很简单,只需写成images[i]即可,通常我会使用i=0。

然后,转置应该将现在的[channel, height, width]张量转换为[height, width, channel]张量。为此,请使用np.transpose(image.numpy(), (1, 2, 0)),这与您原来的写法非常相似。

放在一起,你应该有

# images[0] picks one image from the batch; (1, 2, 0) reorders
# [channel, height, width] into the [height, width, channel] layout matplotlib expects.
plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))
Run Code Online (Sandbox Code Playgroud)

根据具体用例,有时您还需要先调用.detach()(将这部分与计算图分离)和.cpu()(将数据从GPU传输到CPU):

# Same as above, but first moves the tensor to the CPU and detaches it from
# the computation graph before converting it to a NumPy array.
plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))
Run Code Online (Sandbox Code Playgroud)