我想显示一个图像。它使用a加载ImageLoader并存储在PyTorch中Tensor。
当我尝试通过显示它时plt.imshow(image),我得到:
TypeError: Invalid dimensions for image data
Run Code Online (Sandbox Code Playgroud)
该.shape张量是:
torch.Size([3, 244, 244])
Run Code Online (Sandbox Code Playgroud)
如何显示此PyTorch张量中包含的图像?
Tom*_*ale 14
给定一个Tensor代表图像的图像,请使用.permute():
plt.imshow( tensor_image.permute(1, 2, 0) )
Run Code Online (Sandbox Code Playgroud)
注意:permute不会复制或分配内存,也不会。 from_numpy()
处理图像数据的 PyTorch 模块需要格式为C \xc3\x97 H \xc3\x97 W的张量。1
\n而 PILLow 和 Matplotlib 期望图像数组的格式为H \xc3\x97 W \xc3\x97 C。2
您可以使用 TorchVision 变换轻松地将张量与此格式相互转换:
\nfrom torchvision.transforms import functional as F\n\nF.to_pil_image(image_tensor)\nRun Code Online (Sandbox Code Playgroud)\n或者直接排列轴:
\nimage_tensor.permute(1,2,0)\nRun Code Online (Sandbox Code Playgroud)\n\n\n处理图像数据的 PyTorch 模块需要将张量布局为C \xc3\x97 H \xc3\x97 W:分别是通道、高度和宽度。
\n
\n\n请注意我们如何将
\n\npermute轴的顺序从C \xc3\x97 H \xc3\x97 W更改为H \xc3\x97 W \xc3\x97 C以匹配 Matplotlib 的期望。
如您所见,matplotlib即使不转换为numpy数组也可以正常工作。但是PyTorch张量(“图像张量”)是第一个通道,因此要与它们一起使用,matplotlib您需要对其进行重塑:
码:
from scipy.misc import face
import matplotlib.pyplot as plt
import torch
np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# reshape to channel first:
tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1])
print(type(tensor_image), tensor_image.shape)
# If you try to plot image with shape (C, H, W)
# You will get TypeError:
# plt.imshow(tensor_image)
# So we need to reshape it to (H, W, C):
tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0])
print(type(tensor_image), tensor_image.shape)
plt.imshow(tensor_image)
plt.show()
Run Code Online (Sandbox Code Playgroud)
输出:
<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
Run Code Online (Sandbox Code Playgroud)
鉴于图像按描述加载并存储在变量中image:
plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
#transforms.ToPILImage()(image).show() # Alternatively
Run Code Online (Sandbox Code Playgroud)
或者像Soumith 建议的那样:
Run Code Online (Sandbox Code Playgroud)def show(img): npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
| 归档时间: |
|
| 查看次数: |
12327 次 |
| 最近记录: |