log_softmax() 如何实现以更快的速度和数值稳定性计算其值(和梯度)?

her*_*h10 4 python machine-learning numerical-methods mxnet pytorch

MXNet 和 PyTorch 都提供了计算 log 的特殊实现(softmax()),速度更快,数值更稳定。但是,我在这两个包中都找不到该函数 log_softmax() 的实际 Python 实现。

谁能解释一下这是如何实现的,或者更好的是,给我指出相关的源代码?

Vũ *_*Anh 11

  • 数值误差:
>>> x = np.array([1, -10, 1000])
>>> np.exp(x) / np.exp(x).sum()
RuntimeWarning: overflow encountered in exp
RuntimeWarning: invalid value encountered in true_divide
Out[4]: array([ 0.,  0., nan])
Run Code Online (Sandbox Code Playgroud)

有两种方法可以避免计算 softmax 时的数值误差:

  • 指数标准化:

在此输入图像描述

def exp_normalize(x):
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()

>>> exp_normalize(x)
array([0., 0., 1.])
Run Code Online (Sandbox Code Playgroud)
  • 对数和指数

在此输入图像描述

def log_softmax(x):
    c = x.max()
    logsumexp = np.log(np.exp(x - c).sum())
    return x - c - logsumexp

Run Code Online (Sandbox Code Playgroud)

请注意,上式中b、c的合理选择是max(x)。通过这种选择,由于 exp 导致的溢出是不可能的。移位后取幂后的最大数为0。