如何在pytorch中从图像中提取补丁?

Joh*_*all 4 python image-processing pytorch

我想从补丁大小为 128、步幅为 32 的图像中提取图像补丁,所以我有这段代码,但它给了我一个错误:

from PIL import Image 
img = Image.open("cat.jpg")
x = transforms.ToTensor()(img)

x = x.unsqueeze(0)

size = 128 # patch size
stride = 32 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)
Run Code Online (Sandbox Code Playgroud)

我得到的错误是:

RuntimeError: maximum size for tensor at dimension 1 is 3 but size is 128
Run Code Online (Sandbox Code Playgroud)

这是迄今为止我找到的唯一方法。但它给了我这个错误

Mic*_*ngo 11

你的尺寸x[1, 3, height, width]. 调用x.unfold(1, size, stride)尝试从维度 1 创建大小为 128 的切片,维度 1 的大小为 3,因此它太小,无法创建任何切片。

您不想在维度 1 上创建切片,因为这些是图像的通道(在本例中为 RGB),并且需要将它们保留为所有补丁的原样。补丁仅在图像的高度和宽度上创建。

patches = x.unfold(2, size, stride).unfold(3, size, stride)
Run Code Online (Sandbox Code Playgroud)

所得张量的大小为[1, 3, num_vertical_slices, num_horizontal_slices, 128, 128]。您可以重塑它以组合切片以获得补丁列表,即大小为[1, 3, num_patches, 128, 128]

patches = patches.reshape(1, 3, -1, size, size)
Run Code Online (Sandbox Code Playgroud)