Implementing 3D Gaussian blur using separable 2D convolutions in PyTorch

lop*_*ded 5 convolution gaussianblur pytorch

I am trying to implement a Gaussian-like blur of a 3D volume in PyTorch. I can easily blur a 2D image by convolving it with a 2D Gaussian kernel, and the same approach also seems to work for 3D with a 3D Gaussian kernel. However, it is very slow in 3D (especially for larger sigmas/kernel sizes). I understand this could also be done by convolving three times with a 2D kernel, which should be much faster, but I can't get it to work. My test case is below.

import torch
import torch.nn.functional as F

VOL_SIZE = 21


def make_gaussian_kernel(sigma):
    ks = int(sigma * 5)
    if ks % 2 == 0:
        ks += 1
    ts = torch.linspace(-ks // 2, ks // 2 + 1, ks)
    gauss = torch.exp((-(ts / sigma)**2 / 2))
    kernel = gauss / gauss.sum()

    return kernel


def test_3d_gaussian_blur(blur_sigma=2):
    # Make a test volume
    vol = torch.zeros([VOL_SIZE] * 3)
    vol[VOL_SIZE // 2, VOL_SIZE // 2, VOL_SIZE // 2] = 1

    # 3D convolution
    vol_in = vol.reshape(1, 1, *vol.shape)
    k = make_gaussian_kernel(blur_sigma)
    k3d = torch.einsum('i,j,k->ijk', k, k, k)
    k3d = k3d / k3d.sum()
    vol_3d = F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), stride=1, padding=len(k) // 2)

    # Separable 2D convolution
    vol_in = vol.reshape(1, *vol.shape)
    k2d = torch.einsum('i,j->ij', k, k)
    k2d = k2d / k2d.sum()
    k2d = k2d.expand(VOL_SIZE, 1, *k2d.shape)
    for i in range(3):
        vol_in = vol_in.permute(0, 3, 1, 2)
        vol_in = F.conv2d(vol_in, k2d, stride=1, padding=len(k) // 2, groups=VOL_SIZE)
    vol_3d_sep = vol_in

    torch.allclose(vol_3d, vol_3d_sep)  # --> False

Any help would be greatly appreciated!

fla*_*awr 3

In theory you could compute the 3D Gaussian convolution using three 2D convolutions, but that would mean you have to make the 2D kernel narrower, because each direction effectively gets convolved twice.
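To illustrate that point, here is a minimal sketch (my addition, not part of the original answer): a blur with standard deviation sigma is the composition of two blurs with sigma/sqrt(2), and in three axis-aligned 2D passes every axis gets blurred exactly twice, so the 2D kernel has to be built from the reduced sigma. The helper make_gaussian_kernel_1d is hypothetical and samples at integer offsets; since sampled Gaussians only approximately satisfy the composition property, the result only approximately matches a true 3D Gaussian blur.

import math

import torch
import torch.nn.functional as F


def make_gaussian_kernel_1d(sigma):
    # Hypothetical helper (not from the answer): 1D Gaussian sampled at
    # integer offsets, roughly 5*sigma wide and always odd in length.
    radius = max(1, int(2.5 * sigma))
    ts = torch.arange(-radius, radius + 1, dtype=torch.float32)
    gauss = torch.exp(-(ts / sigma) ** 2 / 2)
    return gauss / gauss.sum()


def blur_3d_with_three_2d_passes(vol, sigma):
    # Approximate 3D Gaussian blur of a cubic (n, n, n) volume using three
    # grouped 2D convolutions. Every axis is blurred by exactly two of the
    # three passes, so each pass uses the reduced sigma / sqrt(2).
    n = vol.shape[0]
    k = make_gaussian_kernel_1d(sigma / math.sqrt(2))
    k2d = torch.einsum('i,j->ij', k, k)
    k2d = k2d.expand(n, 1, *k2d.shape)  # one kernel copy per "channel" slice
    vol_in = vol.reshape(1, *vol.shape)
    for _ in range(3):
        # Rotate which axis acts as the channel dimension; the other two get blurred.
        vol_in = vol_in.permute(0, 3, 1, 2)
        vol_in = F.conv2d(vol_in, k2d, stride=1, padding=len(k) // 2, groups=n)
    return vol_in[0]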

But what is computationally more efficient (and what you usually want) is to separate it into 1D kernels. I changed the second part of your function to do this. (And I have to say I really like your permutation-based approach!) Since you are working with a 3D volume, you cannot really use conv2d or conv1d nicely, so the best thing is to just use conv3d, even though you are only computing 1D convolutions.

Note that the default allclose tolerance of 1e-8 is not reached with this method, probably due to cancellation errors.

def test_3d_gaussian_blur(blur_sigma=2):
    # Make a test volume
    vol = torch.randn([VOL_SIZE] * 3) # using something other than zeros
    vol[VOL_SIZE // 2, VOL_SIZE // 2, VOL_SIZE // 2] = 1

    # 3D convolution
    vol_in = vol.reshape(1, 1, *vol.shape)
    k = make_gaussian_kernel(blur_sigma)
    k3d = torch.einsum('i,j,k->ijk', k, k, k)
    k3d = k3d / k3d.sum()
    vol_3d = F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), stride=1, padding=len(k) // 2)

    # Separable 1D convolution
    vol_in = vol[None, None, ...]
    # k2d = torch.einsum('i,j->ij', k, k)
    # k2d = k2d / k2d.sum() # not necessary if kernel already sums to one, check:
    # print(f'{k2d.sum()=}')
    k1d = k[None, None, :, None, None]
    for i in range(3):
        vol_in = vol_in.permute(0, 1, 4, 2, 3)
        vol_in = F.conv3d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0, 0))
    vol_3d_sep = vol_in
    print((vol_3d - vol_3d_sep).abs().max())  # something ~1e-7
    print(torch.allclose(vol_3d, vol_3d_sep)) # allclose checks if it is around 1e-8
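As a quick sanity check of the tolerance remark above (my addition; it assumes vol_3d and vol_3d_sep are still available, e.g. returned from a slightly modified test_3d_gaussian_blur), relaxing atol makes the comparison pass:

# Hypothetical follow-up: assumes vol_3d and vol_3d_sep are returned from
# test_3d_gaussian_blur (a small modification of the function above).
print(torch.allclose(vol_3d, vol_3d_sep))             # False with the default atol=1e-8
print(torch.allclose(vol_3d, vol_3d_sep, atol=1e-6))  # should be True, given the ~1e-7 max difference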

Addendum: If you really want to abuse conv2d to process the volume, you can try

# separate 3d kernel into 1d + 2d
vol_in = vol[None, None, ...]
k2d = torch.einsum('i,j->ij', k, k)
k2d = k2d.expand(VOL_SIZE, 1, len(k), len(k))
# k2d = k2d / k2d.sum() # not necessary if kernel already sums to one, check:
# print(f'{k2d.sum()=}')
k1d = k[None, None, :, None, None]
vol_in = F.conv3d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0, 0))
vol_in = vol_in[0, ...]
# abuse conv2d-groups argument for volume dimension, works only for 1 channel volumes
vol_in = F.conv2d(vol_in, k2d, stride=1, padding=(len(k) // 2, len(k) // 2), groups=VOL_SIZE)
vol_3d_sep = vol_in

Or, using conv2d exclusively, you could do:

# separate 3d kernel into 1d + 2d
vol_in = vol[None,  ...]
# 1d kernel
k1d = k[None, None, :,  None]
k1d = k1d.expand(VOL_SIZE, 1, len(k), 1)
# 2d kernel
k2d = torch.einsum('i,j->ij', k, k)
k2d = k2d.expand(VOL_SIZE, 1, len(k), len(k))
vol_in = vol_in.permute(0, 2, 1, 3)
vol_in = F.conv2d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0), groups=VOL_SIZE)
vol_in = vol_in.permute(0, 2, 1, 3)
vol_in = F.conv2d(vol_in, k2d, stride=1, padding=(len(k) // 2, len(k) // 2), groups=VOL_SIZE)
vol_3d_sep = vol_in

These should still be faster than three consecutive 2D convolutions (the approach in the question).
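If you want to verify that speed claim on your own hardware, a rough timing sketch along these lines (my addition, reusing make_gaussian_kernel from the question; treat the numbers as an order-of-magnitude comparison only) contrasts the full 3D kernel with the three 1D passes:

import time

import torch
import torch.nn.functional as F

sigma, n = 3, 64                                        # larger volume than the toy example
vol = torch.randn(1, 1, n, n, n)
k = make_gaussian_kernel(sigma)                         # from the question
k3d = torch.einsum('i,j,k->ijk', k, k, k)[None, None]   # full (ks, ks, ks) kernel
k1d = k[None, None, :, None, None]                      # 1D kernel for conv3d

t0 = time.perf_counter()
out_full = F.conv3d(vol, k3d, padding=len(k) // 2)
t1 = time.perf_counter()

out_sep = vol
for _ in range(3):
    out_sep = out_sep.permute(0, 1, 4, 2, 3)
    out_sep = F.conv3d(out_sep, k1d, padding=(len(k) // 2, 0, 0))
t2 = time.perf_counter()

print(f"full 3D kernel: {t1 - t0:.3f} s   separable 1D: {t2 - t1:.3f} s")
print("max abs difference:", (out_full - out_sep).abs().max().item())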