Python matplotlib,图像数据的无效形状

Hek*_*kes 3 python matplotlib python-3.x pytorch

目前我有这个代码来显示三个图像:

imshow(image1, title='1')
imshow(image2, title='2')
imshow(image3, title='3')
Run Code Online (Sandbox Code Playgroud)

它工作正常。但是我试图将它们全部放在一行而不是列中。

这是我尝试过的代码:

f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1)
f.add_subplot(1,3,2)
plt.imshow(image2)
f.add_subplot(1,3,3)
plt.imshow(image3)
Run Code Online (Sandbox Code Playgroud)

它抛出

类型错误:无法将 CUDA 张量转换为 numpy。首先使用 Tensor.cpu() 将张量复制到主机内存。

如果我做

f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1.cpu())
f.add_subplot(1,3,2)
plt.imshow(image2.cpu())
f.add_subplot(1,3,3)
plt.imshow(image3.cpu())
Run Code Online (Sandbox Code Playgroud)

它抛出

TypeError: Invalid shape (1, 3, 128, 128) for image data

我应该如何解决这个问题,或者有更简单的方法来实现它?

A. *_*man 7

matplotlib 函数 'imshow' 将 3 通道图片作为 (h, w, 3) 获取,如您在文档中所见。

似乎您传递了图像的三个通道(第二维)的“批次”单个图像(第一维)(h 和 w 是第三和第四维)。

您需要重塑或查看您的图像(转换为 cpu 后,尝试使用:

image1.squeeze().permute(1,2,0)
Run Code Online (Sandbox Code Playgroud)

结果将是所需形状(128、128、3)的图像。

挤压()函数将删除第一维。premute() 函数将调换维度,其中第一个将移到第三个位置,另外两个将移到开头。

另外,请查看此处以进一步讨论 GPU 和 CPU 问题: 链接

希望有帮助。