PyTorch 的“ToPILImage”问题

roc*_*ves 5 python numpy pytorch torchvision

为什么这不起作用?

import torchvision.transforms.functional as tf
from torchvision import transforms
pic = np.random.randint(0, 255+1, size=28*28).reshape(28, 28)
pic = pic.astype(int)
plt.imshow(pic)
t = transforms.ToPILImage()
t(pic.reshape(28, 28, 1))
# tf.to_pil_image(pic.reshape(28, 28, 1))
Run Code Online (Sandbox Code Playgroud)

matplotlib 绘制了一张漂亮的随机图片,但无论我为我的 NumPy ndarray 选择什么数据类型,都不能to_pil_imageToPILImage按预期工作。

文档有这样的说法:

将张量 ...或形状为 H x W x C 的 numpy ndarray转换为PIL 图像,同时保留值范围。... 如果输入有 1 个通道, mode 则由数据类型(即 intfloatshort)决定

除了“short”之外,这些数据类型都不起作用。

其他一切都会导致:

TypeError: Input type int64/float64 is not supported
Run Code Online (Sandbox Code Playgroud)

从扔torchvision/transforms/functional.pyto_pil_image()

此外,即使short数据类型适用于我首先提供的独立代码片段,但在transform.Compose()Dataset对象的调用中使用时它会崩溃__getitem__

choices = transforms.RandomChoice([transforms.RandomAffine(30),
                                   transforms.RandomPerspective()])

transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.RandomApply([choices], 0.5),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

trainset = MNIST('data/train.csv', transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)


Run Code Online (Sandbox Code Playgroud)
RuntimeError: DataLoader worker (pid 12917) is killed by signal: Floating point exception.
RuntimeError: DataLoader worker (pid(s) 12917) exited unexpectedly
Run Code Online (Sandbox Code Playgroud)

ita*_*ter 0

查看 的源代码to_pil_image,您可以看到仅np.{uint8, int16, uint32, float32}支持 numpy 类型的数组。

尝试将图片投射到np.uint8

pic = pic.astype(np.uint8)
Run Code Online (Sandbox Code Playgroud)

那应该对你有用