How do I display a single image in PyTorch?

Tom*_*ale 8 python pytorch

I want to display a single image. It is loaded using an ImageLoader and stored in a PyTorch Tensor.

When I try to display it with plt.imshow(image), I get:

TypeError: Invalid dimensions for image data

The tensor's .shape is:

torch.Size([3, 244, 244])

How can I display the image contained in this PyTorch tensor?

Tom*_*ale 14

Given a Tensor that represents an image, use .permute():

plt.imshow(  tensor_image.permute(1, 2, 0)  )

Note: permute does not copy or allocate memory, and neither does from_numpy().

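  • @DevashishPrasad The problem is that `reshape([224,224,3])` does not do the same thing as `permute(1, 2, 0)`. The `permute` function is like transposing a matrix, where rows become columns and columns become rows. The `reshape` function does something completely unrelated that I don't know how to describe concisely. In short, `reshape` is the wrong function. (2 upvotes)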
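To make the comment's point concrete, here is a minimal sketch on a tiny made-up 3 × 2 × 2 tensor (the values are arbitrary, chosen only so each pixel is easy to track). It compares `permute` with `reshape` and checks that `permute` returns a view over the same storage, as the note above says.

```python
import torch

# Tiny hypothetical "image": 3 channels, 2 rows, 2 columns (C x H x W)
chw = torch.arange(12).reshape(3, 2, 2)

# permute swaps axes: pixel (h, w) keeps its own three channel values
hwc_permute = chw.permute(1, 2, 0)
# reshape regroups the flat data in memory order: channels get mixed into pixels
hwc_reshape = chw.reshape(2, 2, 3)

# permute returns a view, so it points at the same underlying storage (no copy)
shares_storage = hwc_permute.data_ptr() == chw.data_ptr()

print(hwc_permute[0, 0])  # tensor([0, 4, 8]): top-left value of each channel
print(hwc_reshape[0, 0])  # tensor([0, 1, 2]): three values from channel 0 only
```

Both results have shape (2, 2, 3), which is why `reshape` can look plausible, but only `permute` keeps each pixel's channels together.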

uke*_*emi 9

PyTorch modules that process image data expect tensors in C × H × W format,[1]
whereas Pillow and Matplotlib expect image arrays in H × W × C format.[2]


You can easily convert a tensor from this format with the TorchVision transforms:

from torchvision.transforms import functional as F

F.to_pil_image(image_tensor)

Or permute the axes directly:

image_tensor.permute(1, 2, 0)
1. PyTorch modules that process image data expect tensors laid out as C × H × W: channels, height, and width, respectively.
2. Note how we use permute to change the axis order from C × H × W to H × W × C to match what Matplotlib expects.


trs*_*chn 7

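As you can see, matplotlib works fine even without converting to a numpy array. But PyTorch tensors ("image tensors") are channel-first, so to use them with matplotlib you need to rearrange the axes with permute (a plain view or reshape would scramble the pixels instead of moving the channels):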

Code:

import matplotlib.pyplot as plt
import torch

try:
    from scipy.datasets import face  # SciPy >= 1.10 (downloads the data via the pooch package)
except ImportError:
    from scipy.misc import face  # older SciPy versions

np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# Rearrange to channel first (C, H, W). view() would only reinterpret the
# memory and scramble the pixels, so permute() is used to swap the axes:
tensor_image = tensor_image.permute(2, 0, 1)
print(type(tensor_image), tensor_image.shape)

# If you try to plot an image with shape (C, H, W)
# you will get a TypeError:
# plt.imshow(tensor_image)

# So we need to permute it back to (H, W, C):
tensor_image = tensor_image.permute(1, 2, 0)
print(type(tensor_image), tensor_image.shape)

plt.imshow(tensor_image)
plt.show()

Output:

<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
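A short sketch on a hypothetical 2 × 2 × 3 tensor shows why `view` is not a substitute for `permute` when moving channels. `view` only regroups the flat memory, so its "channel-first" result is not a real C × H × W image, yet viewing it back to H × W × C restores the original exactly, which can make a view-then-view round trip look correct when plotted.

```python
import torch

# Small hypothetical H x W x C image: 2 rows, 2 columns, 3 channels
hwc = torch.arange(12).reshape(2, 2, 3)

# view() reinterprets the same flat memory as 3 x 2 x 2; channels are NOT moved
chw_view = hwc.view(3, 2, 2)
# permute() builds the true channel-first layout: chw[c, h, w] == hwc[h, w, c]
chw_permute = hwc.permute(2, 0, 1)

# Viewing back restores the original data, hiding the scrambled middle step
round_trip = chw_view.view(2, 2, 3)

print(chw_permute[0])  # first channel of every pixel
print(chw_view[0])     # just the first four values in memory
```

Anything applied to the channel-first tensor in between (normalization per channel, a model, a transform) would operate on the wrong values in the `view` version.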


Tom*_*ale 5

Given that the image is loaded as described and stored in the variable image:

import matplotlib.pyplot as plt
from torchvision import transforms
plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
# transforms.ToPILImage()(image).show()  # Alternatively

Or, as Soumith suggested:

import matplotlib.pyplot as plt
import numpy as np
def show(img):
    npimg = img.numpy()  # (C, H, W) NumPy array
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')  # -> (H, W, C)
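The `show` function above calls `img.numpy()`, which raises an error for tensors that live on a CUDA device or that require grad. A hedged variant, using a hypothetical helper name `tensor_to_hwc`, detaches the tensor and moves it to the CPU first; the transpose itself is unchanged.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch


def tensor_to_hwc(img: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: C x H x W tensor -> H x W x C NumPy array for Matplotlib."""
    # detach() drops autograd history and cpu() copies CUDA tensors to host memory;
    # without them, .numpy() raises an error for such tensors
    npimg = img.detach().cpu().numpy()
    return np.transpose(npimg, (1, 2, 0))


def show(img: torch.Tensor) -> None:
    plt.imshow(tensor_to_hwc(img), interpolation="nearest")
    plt.show()
```

For example, `tensor_to_hwc(torch.rand(3, 5, 7, requires_grad=True))` returns an array of shape (5, 7, 3), where calling `.numpy()` directly on that tensor would fail.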