在numpy中乘以对数概率矩阵的数值稳定方法

mar*_*art 34 python numpy logarithm matrix matrix-multiplication

我需要获取包含对数概率的两个NumPy矩阵(或其他2d数组)的矩阵乘积.np.log(np.dot(np.exp(a), np.exp(b)))出于显而易见的原因,天真的方式不是优选的.

运用

from scipy.misc import logsumexp
res = np.zeros((a.shape[0], b.shape[1]))
for n in range(b.shape[1]):
    # broadcast b[:,n] over rows of a, sum columns
    res[:, n] = logsumexp(a + b[:, n].T, axis=1) 
Run Code Online (Sandbox Code Playgroud)

工作但运行速度比慢100倍 np.log(np.dot(np.exp(a), np.exp(b)))

运用

logsumexp((tile(a, (b.shape[1],1)) + repeat(b.T, a.shape[0], axis=0)).reshape(b.shape[1],a.shape[0],a.shape[1]), 2).T
Run Code Online (Sandbox Code Playgroud)

或者其他瓦片和重塑的组合也起作用,但是比上面的循环运行得更慢,因为实际大小的输入矩阵需要非常大量的存​​储器.

我目前正在考虑在C中编写一个NumPy扩展来计算它,但当然我宁愿避免这种情况.是否有既定的方法来执行此操作,或者是否有人知道执行此计算的内存密集程度较低的方法?

编辑: 感谢larsmans提供此解决方案(参见下面的推导):

def logdot(a, b):
    max_a, max_b = np.max(a), np.max(b)
    exp_a, exp_b = a - max_a, b - max_b
    np.exp(exp_a, out=exp_a)
    np.exp(exp_b, out=exp_b)
    c = np.dot(exp_a, exp_b)
    np.log(c, out=c)
    c += max_a + max_b
    return c
Run Code Online (Sandbox Code Playgroud)

logdot_old使用iPython的magic %timeit函数快速比较此方法与上面发布的方法()会产生以下结果:

In  [1] a = np.log(np.random.rand(1000,2000))

In  [2] b = np.log(np.random.rand(2000,1500))

In  [3] x = logdot(a, b)

In  [4] y = logdot_old(a, b) # this takes a while

In  [5] np.any(np.abs(x-y) > 1e-14)
Out [5] False

In  [6] %timeit logdot_old(a, b)
1 loops, best of 3: 1min 18s per loop

In  [6] %timeit logdot(a, b)
1 loops, best of 3: 264 ms per loop
Run Code Online (Sandbox Code Playgroud)

显然larsmans的方法抹杀了我的!

Fre*_*Foo 22

logsumexp 通过评估等式的右边来工作

log(? exp[a]) = max(a) + log(? exp[a - max(a)])
Run Code Online (Sandbox Code Playgroud)

即,它在开始求和之前拉出最大值,以防止溢出exp.在做矢量点积之前可以应用相同的方法:

log(exp[a] ? exp[b])
 = log(? exp[a] × exp[b])
 = log(? exp[a + b])
 = max(a + b) + log(? exp[a + b - max(a + b)])     { this is logsumexp(a + b) }
Run Code Online (Sandbox Code Playgroud)

但是通过在推导中采取不同的转向,我们得到了

log(? exp[a] × exp[b])
 = max(a) + max(b) + log(? exp[a - max(a)] × exp[b - max(b)])
 = max(a) + max(b) + log(exp[a - max(a)] ? exp[b - max(b)])
Run Code Online (Sandbox Code Playgroud)

最终形式的内部有一个矢量点积.它也很容易扩展到矩阵乘法,所以我们得到了算法

def logdotexp(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C
Run Code Online (Sandbox Code Playgroud)

这会产生两个A大小的临时数和两个B大小的临时数,但每个临时数都可以消除

exp_A = A - max_A
np.exp(exp_A, out=exp_A)
Run Code Online (Sandbox Code Playgroud)

并且类似地B.(如果函数可以修改输入矩阵,则可以消除所有临时值.)

  • 这不如最初的较慢解决方案稳定。考虑 logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]])。 (2认同)

Has*_*san 5

假设A.shape==(n,r)B.shape==(r,m)。在计算矩阵乘积时C=A*B,实际上存在n*m求和。为了在对数空间中工作时获得稳定的结果,您需要在每个求和中使用 logsumexp 技巧。幸运的是,使用 numpy 广播可以很容易地分别控制 A 和 B 的行和列的稳定性。

这是代码:

def logdotexp(A, B):
    max_A = np.max(A,1,keepdims=True)
    max_B = np.max(B,0,keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C
Run Code Online (Sandbox Code Playgroud)

笔记:

这背后的原因与 FredFoo 的答案类似,但他为每个矩阵使用了一个最大值。由于他没有考虑每一项n*m求和,因此最终矩阵的某些元素可能仍然不稳定,如评论之一所述。

使用 @identity-m 反例与当前接受的答案进行比较:

def logdotexp_less_stable(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

print('old method:')
print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
print('new method:')
print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
Run Code Online (Sandbox Code Playgroud)

打印

old method:
[[      -inf 0.69314718]
 [      -inf 0.69314718]]
new method:
[[-9.99306853e+02  6.93147181e-01]
 [-9.99306853e+02  6.93147181e-01]]
Run Code Online (Sandbox Code Playgroud)