Pytorch 中是否有类似于 OpenCV 中的 cv2.dilate 的“张量”操作或函数?

lar*_*ang 4 python opencv pytorch

我通过网络制作了几个面具。这些掩码存储在一个torch.tensor变量中。我想cv2.dilatetensor.

我知道有一种方法可以使用循环将其转换tensornumpy.ndarray然后应用于cv2.dilate每个通道for。但是由于大约有 32 个通道,这种方法可能会减慢网络中的前向操作。

Ard*_*iya 13

我认为扩张本质上是火炬中的 conv2d 操作。看下面的代码

import cv2
import numpy as np
import torch

im = np.array([ [0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0],
                [0, 1, 1, 0, 0],
                [0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0] ], dtype=np.float32)
kernel = np.array([ [1, 1, 1],
                    [1, 1, 1],
                    [1, 1, 1] ], dtype=np.float32)
print(cv2.dilate(im, kernel))
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [0. 0. 1. 1. 1.]]
im_tensor = torch.Tensor(np.expand_dims(np.expand_dims(im, 0), 0)) # size:(1, 1, 5, 5)
kernel_tensor = torch.Tensor(np.expand_dims(np.expand_dims(kernel, 0), 0)) # size: (1, 1, 3, 3)
torch_result = torch.clamp(torch.nn.functional.conv2d(im_tensor, kernel_tensor, padding=(1, 1)), 0, 1)
print(torch_result)
# tensor([[[[1., 1., 1., 0., 0.],
#           [1., 1., 1., 1., 0.],
#           [1., 1., 1., 1., 1.],
#           [1., 1., 1., 1., 1.],
#           [0., 0., 1., 1., 1.]]]])
Run Code Online (Sandbox Code Playgroud)

  • 我坚持认为 dilate 和 conv2d 不是同一个操作! (3认同)

Man*_*nza 7

编辑:

我最近与kornia合作,现在形态学操作按预期工作。

编辑:

我创建了一个库来做到这一点;该库称为nnMorpho,可以通过 pip install nnMorpho. 我使用的原理如下所述(:使用 PyTorch 中的展开函数)。目前该库处于早期阶段(仅实现了基本操作),但我将尝试更新它以包含更多种类的操作和参数。

Dilation 和 convd2d 不一样

Dilation 和 convd2d 根本不一样:粗略地说,Convd2d 执行线性滤波器(这意味着它在像素周围进行深思熟虑的求和),而 dilation 执行非线性滤波器(获取像素周围的最大值)。

在 PyTorch 中进行形态学分析的一种方法

PyTorch 有一种方法可以进行数学形态学运算。处理膨胀和腐蚀时面临的主要问题是,您必须考虑每个像素的邻域来计算最大值(如果处理灰度结构元素,则可能需要考虑总和和差)。这个问题可以通过PyTorch 中的Expand函数解决;它目前仅支持批量的类似图像的张量(即:尺寸为(B,C,H,W)的4D张量),但这对于您的需求来说应该不是问题。剩下的只是正常操作。

我加入了执行膨胀(侵蚀类似)的代码和示例:

import numpy as np
import torch
from torch.nn import functional as f
from scipy.ndimage import grey_dilation as dilation_scipy
import matplotlib.pyplot as plt


# Definition of the dilation using PyTorch
def dilation_pytorch(image, strel, origin=(0, 0), border_value=0):
    # first pad the image to have correct unfolding; here is where the origins is used
    image_pad = f.pad(image, [origin[0], strel.shape[0] - origin[0] - 1, origin[1], strel.shape[1] - origin[1] - 1], mode='constant', value=border_value)
    # Unfold the image to be able to perform operation on neighborhoods
    image_unfold = f.unfold(image_pad.unsqueeze(0).unsqueeze(0), kernel_size=strel.shape)
    # Flatten the structural element since its two dimensions have been flatten when unfolding
    strel_flatten = torch.flatten(strel).unsqueeze(0).unsqueeze(-1)
    # Perform the greyscale operation; sum would be replaced by rest if you want erosion
    sums = image_unfold + strel_flatten
    # Take maximum over the neighborhood
    result, _ = sums.max(dim=1)
    # Reshape the image to recover initial shape
    return torch.reshape(result, image.shape)


# Test image
image = np.zeros((7, 7), dtype=int)
image[2:5, 2:5] = 1
image[4, 4] = 2
image[2, 3] = 3

plt.figure()
plt.imshow(image, cmap='Greys', vmin=image.min(), vmax=image.max(), origin='lower')
plt.title('Original image')

# Structural element square 3x3
strel = np.ones((3, 3))

# Origin of the structural element
origin = (1, 1)

# Scipy
dilated_image_scipy = dilation_scipy(image, size=(3, 3), structure=strel)

plt.figure()
plt.imshow(dilated_image_scipy, cmap='Greys', vmin=image.min(), vmax=image.max(), origin='lower')
plt.title('Dilated image - Scipy')

# PyTorch
image_tensor = torch.tensor(image, dtype=torch.float)
strel_tensor = torch.tensor(strel, dtype=torch.float)
dilated_image_pytorch = dilation_pytorch(image_tensor, strel_tensor, origin=origin, border_value=-1000)

plt.figure()
plt.imshow(dilated_image_pytorch.cpu().numpy(), cmap='Greys', vmin=image.min(), vmax=image.max(), origin='lower')
plt.title('Dilated image - PyTorch')

plt.show()
Run Code Online (Sandbox Code Playgroud)

Scipy 文档中提出的原始图像

scipy 的放大图像

pytorch 的放大图像

关于起源的思考

原点是膨胀和腐蚀的关键参数。它可以移动图像。如果您希望图像不移动,则应将其放置在中间(这意味着具有奇数尺寸的结构元素)。我尝试在 scipy 中使用它,但效果不是很好,因为它在所有维度上都是相同的(这在处理非方形结构元素时会带来问题)。我展示的代码正确地考虑了起源。