高效计算帽子矩阵对角线:inv(X'WX)'X'

use*_*956 5 numpy matrix

作为疾病风险模型的一部分,我正在尝试实现论文(Python/numpy)中的计算,其中一部分是以下矩阵计算:

\n\n

富+巴兹

\n\n

在哪里:

\n\n
    \n
  • X 和 [nm]。m 大小合理 (1..200),但 n 相当大 (>500k)
  • \n
  • W 是一个 [nn],并且过大,但仅对角线上有非零值
  • \n
\n\n

另外,我只需要获取 Q 的对角元素作为输出。

\n\n

是否有一些 numpy 矩阵魔法可以让我有效地计算这个?

\n\n

笔记:

\n\n

论文在R中有一个实现实现,我(相信)这样做如下:

\n\n
Qdiag <- lm.influence(lm(y ~ X-1, weights=W))$hat/W\n
Run Code Online (Sandbox Code Playgroud)\n\n

R\'s lm.influence$hat 文档表示这给出了“一个包含 \xe2\x80\x98hat\xe2\x80\x99 矩阵对角线的向量”。尽管维基百科的定义听起来有点像我想要的对帽子矩阵(==影响或投影矩阵)的定义看起来略有不同。

\n\n

--

\n\n

我认为以下是一个有效的(天真的)实现。对于大 n 来说内存不足

\n\n
m = 3\nn = 20 # 500000 -- out of memory for large n\n\nnp.random.seed(123)\nX = np.random.random((n,m))\nW = np.random.random(n)\nW = np.diag(W)\nxtwx = X.T.dot(W.dot(X))\nxtwxi = np.linalg.pinv(xtwx)\nxtwxixt = xtwxi.dot(X.T)\nQ = X.dot(xtwxixt)\nQdiag = np.diag(Q)\nprint Qdiag.shape, Qdiag.sum() # Checksum of output \nprint Qdiag\n
Run Code Online (Sandbox Code Playgroud)\n

sas*_*cha 3

(在现已删除的评论中,我说鉴于某些密度和硬件假设将其视为黑盒,这是不可能的。但似乎可以做到。这并不意味着这是正确的方法!)

因此,在不分析这个公式的背景的情况下,我们可以在给出最小假设和经典规则的情况下做一些基本方法,例如:

  • A:矩阵乘法的结合律
  • B:求解线性方程组而不是求逆
    • 我们假设 XtWX 是非奇异的
  • C:识别 A*W(仅 W 对角线)是与对角向量的逐行元素乘积
  • D:仅计算 Q 的对角线条目(否则我们将得到一个 N*N = 2.5e8 个数字条目的结果矩阵)

代码:

import numpy as np
from time import perf_counter as pc     # python 3 only

m = 200
n = 500000

np.random.seed(123)
X = np.random.random((n,m))
W_diag = np.random.random(n)            # C -> dense vector

start_time = pc()

lhs = np.multiply(X.T, W_diag).dot(X)   # C (+A)
x = np.linalg.solve(lhs, X.T)           # B

# EDIT: Paul Panzer recommends the inverse in his comment based on the rhs-dims!

# if you know something about lhs (looks symmetric; maybe even PSD)
# use one of the following for speedups and more robustness
# i highly recommend this research: feels PSD
# import scipy.linalg as slin
# x = slin.solve(lhs, X.T, assume_a='sym')
# x = slin.solve(lhs, X.T, assume_a='pos')

Q_ = np.einsum('ij,ji->i', X,x)         # D most important -> N*N elsewise
print(Q_)

end_time = pc()
print(end_time - start_time)
Run Code Online (Sandbox Code Playgroud)

出去:

[ 0.00068103  0.00083676  0.00080945 ...,  0.00077864  0.00078945
  0.0007804 ]
3.1077745566331165  # seconds
Run Code Online (Sandbox Code Playgroud)

与为单个测试用例给出的代码相比,结果是相同的!

一般来说,我建议遵循基础数学,而不是提取的公式本身。由于该论文基本上说更完整的问题是加权最小二乘问题,因此这对于某些研究来说是一个良好的开端。

  • 仅当要求解的列数不太大时,求解才会更便宜。在给定的场景中,情况显然并非如此,您会发现简单地求逆实际上比求解更快。 (2认同)