使用 pytorch 验证卷积定理

Bar*_*der 4 fft convolution theorem-proving pytorch

基本上这个定理的公式如下:

F(f*g) = F(f)xF(g)

我知道这个定理,但我只是无法使用 pytorch 重现结果。

下面是一个可重现的代码:

import torch
import torch.nn.functional as F

# calculate f*g
f = torch.ones((1,1,5,5))
g = torch.tensor(list(range(9))).view(1,1,3,3).float()
conv = F.conv2d(f, g, bias=None, padding=2)

# calculate F(f*g)
F_fg = torch.rfft(conv, signal_ndim=2, onesided=False)

# calculate F x G
f = f.squeeze()
g = g.squeeze()

# need to pad into at least [w1+w2-1, h1+h2-1], which is 7 in our case.
size = f.size(0) + g.size(0) - 1 

f_new = torch.zeros((7,7))
g_new = torch.zeros((7,7))

f_new[1:6,1:6] = f
g_new[2:5,2:5] = g

F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
FxG = torch.mul(F_f, F_g)

print(FxG - F_fg)
Run Code Online (Sandbox Code Playgroud)

这是打印的结果(FxG - F_fg)

tensor([[[[[ 0.0000e+00,  0.0000e+00],
       [ 4.1426e+02,  1.7270e+02],
       [-3.6546e+01,  4.7600e+01],
       [-1.0216e+01, -4.1198e+01],
       [-1.0216e+01, -2.0223e+00],
       [-3.6546e+01, -6.2804e+01],
       [ 4.1426e+02, -1.1427e+02]],

      ...

      [[ 4.1063e+02, -2.2347e+02],
       [-7.6294e-06,  2.2817e+01],
       [-1.9024e+01, -9.0105e+00],
       [ 7.1708e+00, -4.1027e+00],
       [-2.6739e+00, -1.1121e+01],
       [ 8.8471e+00,  7.1710e+00],
       [ 4.2528e+01,  9.7559e+01]]]]])
Run Code Online (Sandbox Code Playgroud)

你可以看到差异并不总是 0。

有人能告诉我为什么以及如何正确地做到这一点吗?

谢谢

jod*_*dag 8

所以我仔细研究了你到目前为止所做的事情。我已经确定了您代码中的三个错误来源。我将在这里尝试充分解决每个问题。

1. 复数运算

PyTorch 目前不支持复数乘法 (AFAIK)。FFT 运算仅返回具有实部和虚部维度的张量。我们需要显式地编写复杂的乘法代码,而不是使用torch.mul*运算符。

(a + ib) * (c + id) = (a*c - b*d) + i(a*d + b*c)

2.卷积的定义

CNN文献中经常使用的“卷积”的定义,实际上与讨论卷积定理时使用的定义不同。我不会详细介绍,但理论定义在滑动和乘法之前翻转内核。相反,pytorch、tensorflow、caffe 等中的卷积操作不会进行这种翻转。

为了解决这个问题,我们可以g在应用 FFT 之前简单地翻转(水平和垂直)。

3. 锚定位置

假设使用卷积定理时的锚点是 padded 的左上角g。同样,我不会详细介绍这一点,但这是数学计算的方式。


举个例子,第二点和第三点可能更容易理解。假设你使用了以下g

[1 2 3]
[4 5 6]
[7 8 9]
Run Code Online (Sandbox Code Playgroud)

而不是g_new

[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
Run Code Online (Sandbox Code Playgroud)

它实际上应该是

[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]
Run Code Online (Sandbox Code Playgroud)

我们垂直和水平翻转内核,然后应用循环移位,使内核的中心位于左上角。


我最终重写了您的大部分代码并对其进行了一些概括。最复杂的操作是g_new正确定义。我决定使用网格和模运算来同时翻转和移动索引。如果这里的某些内容对您没有意义,请发表评论,我会尽力澄清。

[1 2 3]
[4 5 6]
[7 8 9]
Run Code Online (Sandbox Code Playgroud)

这给了我

Average difference: 4.6866085767760524e-07
Run Code Online (Sandbox Code Playgroud)

这非常接近于零。我们没有得到完全为零的原因仅仅是由于浮点错误。