试图理解计算 Torch 中 LogSoftMax 输入的梯度的代码

lar*_*ars 0 mathematical-optimization gradient-descent torch softmax

代码来自:https : //github.com/torch/nn/blob/master/lib/THNN/generic/LogSoftMax.c

我没有看到这段代码如何计算模块 LogSoftMax 输入的梯度。我感到困惑的是两个 for 循环在做什么。

for (t = 0; t < nframe; t++)
{
sum = 0;
gradInput_data = gradInput_data0 + dim*t;
output_data = output_data0 + dim*t;
gradOutput_data = gradOutput_data0 + dim*t;

for (d = 0; d < dim; d++)
  sum += gradOutput_data[d];

for (d = 0; d < dim; d++)
  gradInput_data[d] = gradOutput_data[d] - exp(output_data[d])*sum;
 }
}
Run Code Online (Sandbox Code Playgroud)

del*_*eil 5

在前进时,我们有(x = 输入向量,y = 输出向量,f = logsoftmax,i = 第 i 个分量):

yi = f(xi)
   = log( exp(xi) / sum_j(exp(xj)) )
   = xi - log( sum_j(exp(xj)) )
Run Code Online (Sandbox Code Playgroud)

在计算 f 的雅可比 Jf 时,您有(第 i 行):

dyi/dxi = 1 - exp(xi) / sum_j(exp(xj))
Run Code Online (Sandbox Code Playgroud)

对于与 i 不同的 k:

dyi/dxk = - exp(xk) / sum_j(exp(xj))
Run Code Online (Sandbox Code Playgroud)

这为 Jf 提供:

1-E(x1)     -E(x2)     -E(x3)    ...
 -E(x1)    1-E(x2)     -E(x3)    ...
 -E(x1)     -E(x2)    1-E(x3)    ...
...
Run Code Online (Sandbox Code Playgroud)

E(xi) = exp(xi) / sum_j(exp(xj))

如果我们将 gradInput 命名为梯度 wrt 输入和 gradOutput 梯度 wrt 输出,则反向传播给出(链式法则):

gradInputi = sum_j( gradOutputj . dyj/dxi )
Run Code Online (Sandbox Code Playgroud)

这相当于:

gradInput = transpose(Jf) . gradOutput
Run Code Online (Sandbox Code Playgroud)

最后给出了第 i 个分量:

gradInputi = gradOutputi - E(xi) . sum_j( gradOutputj )
Run Code Online (Sandbox Code Playgroud)

因此,第一个循环预先计算sum_j( gradOutputj ),最后一个循环预先计算上述术语,即 grad 的第 i 个分量。输入 - 除了1 / sum_j(exp(xj))Torch 实现中缺少指数项(上述演算可能应该仔细检查,即使它听起来正确并解释了当前实现)。

更新缺少 1 / sum_j(exp(xj))术语没有问题。由于 jacobian 是在输出值上计算的,并且由于之前计算的输出恰好是 log-softmax 分布,因此该分布的 sum-exp 为 1:

sum_j(exp(outputj)) = sum_j(exp( log(exp(inputj) / sum_k(exp(inputk) ))
                    = sum_j(         exp(inputj) / sum_k(exp(inputk)  )
                    = 1
Run Code Online (Sandbox Code Playgroud)

所以没有必要在实现中明确这个术语,它给出(对于 x = 输出):

gradInputi = gradOutputi - exp(outputi) . sum_j( gradOutputj )
Run Code Online (Sandbox Code Playgroud)