Aer*_*yes 3 python numpy machine-learning conv-neural-network pytorch
我是Pytorch的新手。在开始使用CNN进行训练之前,我一直在尝试学习如何查看输入的图像。我很难将图像更改为可与matplotlib一起使用的形式。
到目前为止,我已经尝试过了:
from multiprocessing import freeze_support
import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
import PIL
num_classes = 5
batch_size = 100
num_of_workers = 5
DATA_PATH_TRAIN = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\train'
DATA_PATH_TEST = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\test'
trans = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(32),
transforms.CenterCrop(32),
transforms.ToPImage(),
transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
])
train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
print(npimg)
plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
def main():
# get some random training images
dataiter = iter(train_loader)
images, labels = dataiter.next()
# show images
imshow(images)
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
if __name__ == "__main__":
main()
Run Code Online (Sandbox Code Playgroud)
但是,这引发了错误:
[[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353
0.27450982]
[0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646
0.40784314]
[0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354
0.34509805]
...
[0.2784314 0.21960783 0.2352941 ... 0.5803922 0.46666667
0.25882354]
[0.26666668 0.16862744 0.23137254 ... 0.2901961 0.29803923
0.2509804 ]
[0.30980393 0.39607844 0.28627452 ... 0.1490196 0.10588235
0.19607842]]
[[0.2352941 0.06274509 0.15686274 ... 0.09411764 0.3019608
0.19215685]
[0.22745097 0.07843137 0.12549019 ... 0.07843137 0.10588235
0.3019608 ]
[0.20392156 0.13333333 0.1607843 ... 0.16862744 0.2117647
0.22745097]
...
[0.18039215 0.16862744 0.1490196 ... 0.45882353 0.36078432
0.16470587]
[0.1607843 0.10588235 0.14117646 ... 0.2117647 0.18039215
0.10980392]
[0.18039215 0.3019608 0.2117647 ... 0.11372548 0.06274509
0.04705882]]]
...
[[[0.8980392 0.8784314 0.8509804 ... 0.627451 0.627451
0.627451 ]
[0.8509804 0.8235294 0.7921569 ... 0.54901963 0.5568628
0.56078434]
[0.7921569 0.7529412 0.7176471 ... 0.47058824 0.48235294
0.49411765]
...
[0.3764706 0.38431373 0.3764706 ... 0.4509804 0.43137255
0.39607844]
[0.38431373 0.39607844 0.3882353 ... 0.4509804 0.43137255
0.39607844]
[0.3882353 0.4 0.39607844 ... 0.44313726 0.42352942
0.39215687]]
[[0.9254902 0.90588236 0.88235295 ... 0.60784316 0.6
0.5921569 ]
[0.88235295 0.85490197 0.8235294 ... 0.5411765 0.5372549
0.53333336]
[0.8235294 0.7882353 0.75686276 ... 0.47058824 0.47058824
0.47058824]
...
[0.50980395 0.5176471 0.5137255 ... 0.58431375 0.5647059
0.53333336]
[0.5137255 0.53333336 0.5254902 ... 0.58431375 0.5686275
0.53333336]
[0.5176471 0.53333336 0.5294118 ... 0.5764706 0.56078434
0.5294118 ]]
[[0.95686275 0.9372549 0.90588236 ... 0.18823528 0.19999999
0.20784312]
[0.9098039 0.8784314 0.8352941 ... 0.1607843 0.17254901
0.18039215]
[0.84313726 0.7921569 0.7490196 ... 0.1372549 0.14509803
0.15294117]
...
[0.03921568 0.05490196 0.05098039 ... 0.11764705 0.09411764
0.02745098]
[0.04705882 0.07843137 0.06666666 ... 0.12156862 0.10196078
0.03529412]
[0.05098039 0.0745098 0.07843137 ... 0.12549019 0.10196078
0.04705882]]]
[[[0.30588236 0.28627452 0.24313724 ... 0.2901961 0.26666668
0.21568626]
[0.8156863 0.6666667 0.5921569 ... 0.18039215 0.23921567
0.21568626]
[0.9019608 0.83137256 0.85490197 ... 0.21960783 0.36862746
0.23921567]
...
[0.7058824 0.83137256 0.85490197 ... 0.2627451 0.24313724
0.20784312]
[0.7137255 0.84313726 0.84705883 ... 0.26666668 0.29803923
0.21568626]
[0.7254902 0.8235294 0.8392157 ... 0.2509804 0.27058825
0.2352941 ]]
[[0.24705881 0.22745097 0.19215685 ... 0.2784314 0.25490198
0.19607842]
[0.59607846 0.37254903 0.29803923 ... 0.16470587 0.22745097
0.20392156]
[0.5921569 0.4509804 0.49803922 ... 0.20784312 0.3764706
0.2352941 ]
...
[0.42352942 0.4627451 0.42352942 ... 0.23921567 0.23137254
0.19999999]
[0.45882353 0.5176471 0.35686275 ... 0.23921567 0.26666668
0.19607842]
[0.41568628 0.44313726 0.34901962 ... 0.21960783 0.23921567
0.21568626]]
[[0.23137254 0.20784312 0.1490196 ... 0.30588236 0.28627452
0.19607842]
[0.61960787 0.3764706 0.26666668 ... 0.16470587 0.24313724
0.21568626]
[0.57254905 0.43137255 0.48235294 ... 0.2235294 0.40392157
0.25882354]
...
[0.4 0.42352942 0.37254903 ... 0.25490198 0.24705881
0.21568626]
[0.43137255 0.4509804 0.29411766 ... 0.25882354 0.28235295
0.20392156]
[0.38431373 0.3529412 0.25490198 ... 0.2352941 0.25490198
0.23137254]]]
[[[0.06274509 0.09019607 0.11372548 ... 0.5803922 0.5176471
0.59607846]
[0.09411764 0.14509803 0.1372549 ... 0.5294118 0.49803922
0.5058824 ]
[0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942
0.38431373]
...
[0.15294117 0.12941176 0.1607843 ... 0.85882354 0.8509804
0.80784315]
[0.14509803 0.10588235 0.1607843 ... 0.8666667 0.85882354
0.8 ]
[0.1490196 0.10588235 0.16470587 ... 0.827451 0.8156863
0.7921569 ]]
[[0.06666666 0.12156862 0.17647058 ... 0.59607846 0.5529412
0.6039216 ]
[0.07058823 0.10588235 0.11764705 ... 0.56078434 0.5254902
0.5372549 ]
[0.03921568 0.0745098 0.09803921 ... 0.48235294 0.4392157
0.4117647 ]
...
[0.2117647 0.14509803 0.2784314 ... 0.43137255 0.3529412
0.34117648]
[0.2235294 0.11372548 0.2509804 ... 0.4509804 0.39607844
0.2509804 ]
[0.25490198 0.12156862 0.24705881 ... 0.38039216 0.36078432
0.3254902 ]]
[[0.05490196 0.09803921 0.12549019 ... 0.46666667 0.38039216
0.45490196]
[0.06274509 0.09803921 0.10196078 ... 0.44705883 0.41568628
0.3882353 ]
[0.03921568 0.06666666 0.0862745 ... 0.3764706 0.33333334
0.28235295]
...
[0.12156862 0.14509803 0.16862744 ... 0.15686274 0.0745098
0.09411764]
[0.10588235 0.11372548 0.16862744 ... 0.25882354 0.18431371
0.05490196]
[0.12156862 0.11372548 0.17254901 ... 0.2352941 0.17254901
0.14117646]]]]
Traceback (most recent call last):
File "image_loader.py", line 51, in <module>
main()
File "image_loader.py", line 46, in main
imshow(images)
File "image_loader.py", line 38, in imshow
plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 598, in transpose
return _wrapfunc(a, 'transpose', axes)
File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 51, in _wrapfunc
return getattr(obj, method)(*args, **kwds)
ValueError: repeated axis in transpose
Run Code Online (Sandbox Code Playgroud)
我试图打印出数组以获得尺寸,但是我不知道该怎么做。这很令人困惑。
这是我的直接问题:在使用DataLoader对象中的张量进行训练之前,如何查看输入图像?
首先,dataloader输出4维张量- [batch, channel, height, width]。Matplotlib和其他图像处理库经常需要[height, width, channel]。您使用转置是正确的,只是使用方式不正确。
您的图像很多,images因此首先您需要选择一个图像(或编写一个for循环以保存所有图像)。这将很简单images[i],通常我会使用i=0。
然后,转置应该将现在的[channel, height, width]张量转换为一个张量[height, width, channel]。为此np.transpose(image.numpy(), (1, 2, 0)),请非常像您一样使用。
放在一起,你应该有
plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))
Run Code Online (Sandbox Code Playgroud)
有时您需要根据用例进行调用.detach()(将这部分与计算图分开)和.cpu()(将数据从GPU传输到CPU),具体取决于
plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2283 次 |
| 最近记录: |