her*_*h10 4 python machine-learning numerical-methods mxnet pytorch
MXNet 和 PyTorch 都提供了计算 log 的特殊实现(softmax()),速度更快,数值更稳定。但是,我在这两个包中都找不到该函数 log_softmax() 的实际 Python 实现。
谁能解释一下这是如何实现的,或者更好的是,给我指出相关的源代码?
Vũ *_*Anh 11
>>> x = np.array([1, -10, 1000])
>>> np.exp(x) / np.exp(x).sum()
RuntimeWarning: overflow encountered in exp
RuntimeWarning: invalid value encountered in true_divide
Out[4]: array([ 0., 0., nan])
Run Code Online (Sandbox Code Playgroud)
有两种方法可以避免计算 softmax 时的数值误差:
def exp_normalize(x):
b = x.max()
y = np.exp(x - b)
return y / y.sum()
>>> exp_normalize(x)
array([0., 0., 1.])
Run Code Online (Sandbox Code Playgroud)
def log_softmax(x):
c = x.max()
logsumexp = np.log(np.exp(x - c).sum())
return x - c - logsumexp
Run Code Online (Sandbox Code Playgroud)
请注意,上式中b、c的合理选择是max(x)。通过这种选择,由于 exp 导致的溢出是不可能的。移位后取幂后的最大数为0。
| 归档时间: |
|
| 查看次数: |
6544 次 |
| 最近记录: |