为什么使用matplotlib无法正确显示CIFAR-10图像?

Sid*_*rth 5 python image machine-learning matplotlib computer-vision

从训练集中我拍了一张大小为(3,32,32)的图像('img').我用过plt.imshow(img.T).图像不清晰.现在我必须对图像('img')进行更改,以使其更清晰可见.谢谢.

这是我得到的形象

use*_*359 18

下面打印5X5网格随机Cifar10图像.它并不模糊,但也不完美.欢迎任何建议.

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from six.moves import cPickle 

f = open('data/cifar10/cifar-10-batches-py/data_batch_1', 'rb')
datadict = cPickle.load(f,encoding='latin1')
f.close()
X = datadict["data"] 
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("uint8")
Y = np.array(Y)

#Visualizing CIFAR 10
fig, axes1 = plt.subplots(5,5,figsize=(3,3))
for j in range(5):
    for k in range(5):
        i = np.random.choice(range(len(X)))
        axes1[j][k].set_axis_off()
        axes1[j][k].imshow(X[i:i+1][0])
Run Code Online (Sandbox Code Playgroud)

  • 这是很好的信息,但 astype("float") 会将图像显示为负数,将其设置为 astype("uint8") 是正常的。 (2认同)

Sou*_*nak 6

该文件读取cifar10 数据集并使用matplotlib.

import _pickle as pickle
import argparse
import numpy as np
import os
import matplotlib.pyplot as plt

cifar10 = "./cifar-10-batches-py/"

parser = argparse.ArgumentParser("Plot training images in cifar10 dataset")
parser.add_argument("-i", "--image", type=int, default=0, 
                    help="Index of the image in cifar10. In range [0, 49999]")
args = parser.parse_args()


def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def cifar10_plot(data, meta, im_idx=0):
    im = data[b'data'][im_idx, :]

    im_r = im[0:1024].reshape(32, 32)
    im_g = im[1024:2048].reshape(32, 32)
    im_b = im[2048:].reshape(32, 32)

    img = np.dstack((im_r, im_g, im_b))

    print("shape: ", img.shape)
    print("label: ", data[b'labels'][im_idx])
    print("category:", meta[b'label_names'][data[b'labels'][im_idx]])         

    plt.imshow(img) 
    plt.show()


def main():
    batch = (args.image // 10000) + 1
    idx = args.image - (batch-1)*10000

    data = unpickle(os.path.join(cifar10, "data_batch_" + str(batch)))
    meta = unpickle(os.path.join(cifar10, "batches.meta"))

    cifar10_plot(data, meta, im_idx=idx)


if __name__ == "__main__":
    main()
Run Code Online (Sandbox Code Playgroud)


Ale*_*ton 5

当您要显示图像时,请确保不要对数据集进行标准化。

例子 :

装载机...

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                        #  transforms.Normalize(
                        #      (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                     ])),
    batch_size=64, shuffle=True)
Run Code Online (Sandbox Code Playgroud)

显示图像的代码...

img = next(iter(train_loader))[0][0]
plt.imshow(transforms.ToPILImage()(img))
Run Code Online (Sandbox Code Playgroud)

归一化

归一化

没有标准化

未标准化