是否有在PyTorch中提取图像补丁的功能?

Gab*_*oni 9 python tensorflow pytorch

鉴于一批图像,我想提取所有可能的图像补丁,类似于卷积.在TensorFlow中,我们可以tf.extract_image_patches用来实现这一目标.PyTorch中有相同的功能吗?

谢谢.

Dal*_*yaG 7

也许这个代码示例将有助于理解如何使用unfold,受到@gasoon 链接的这个线程的启发,但更详细一些:

batch_size, n_channels, n_rows, n_cols = 32, 3, 64, 64
kernel_h, kernel_w = 7, 9
step = 5

x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)

# unfold(dimension, size, step)
windows = x.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(2, 3, 0, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([4608, 3, 7, 9]) = [n_windows, n_channels, krenel_h, kernel_w]
Run Code Online (Sandbox Code Playgroud)


gas*_*oon 5

不幸的是,可能没有直接的方法来实现您的目标。
但是 Tensor.unfold 函数可能是一个解决方案。
https://discuss.pytorch.org/t/how-to-extract-smaller-image-patches-3d/16837/2
这个网站可能对你有帮助。