为什么X.dot(XT)在numpy中需要这么多内存?

blu*_*cat 6 python numpy linear-algebra scipy

X是焦虑矩阵,其中p远大于n.假设n = 1000且p = 500000.当我运行时:

X = np.random.randn(1000,500000)
S = X.dot(X.T)
Run Code Online (Sandbox Code Playgroud)

尽管结果大小为1000 x 1000,但执行此操作最终会占用大量内存.一旦操作完成,内存使用将恢复.有没有办法解决?

ali*_*i_m 6

这个问题是不是XX.T都享有相同的内存空间的本身,而是X.T为F-连续的,而不是C-连续的.当然,在将数组与其转置视图相乘的情况下,对于至少一个输入数组必须如此.

在numpy <1.8中,np.dot将创建任何 F有序输入数组的C有序副本,而不仅仅是恰好在同一内存块上的视图.

例如:

X = np.random.randn(1000,50000)
Y = np.random.randn(50000, 100)

# X and Y are both C-order, no copy
%memit np.dot(X, Y)
# maximum of 1: 485.554688 MB per loop

# make X Fortran order and Y C-order, now the larger array (X) gets
# copied
X = np.asfortranarray(X)
%memit np.dot(X, Y)
# maximum of 1: 867.070312 MB per loop

# make X C-order and  Y Fortran order, now the smaller array (Y) gets
# copied
X = np.ascontiguousarray(X)
Y = np.asfortranarray(Y)
%memit np.dot(X, Y)
# maximum of 1: 523.792969 MB per loop

# make both of them F-ordered, both get copied!
X = np.asfortranarray(X)
%memit np.dot(X, Y)
# maximum of 1: 905.093750 MB per loop
Run Code Online (Sandbox Code Playgroud)

如果复制是一个问题(例如,当X非常大时),你能做些什么呢?

最好的选择可能是升级到更新版本的numpy - 正如@perimosocordiae指出的那样,这个性能问题在这个拉取请求中得到了解决.

如果由于某种原因你不能升级numpy,还有一个技巧,允许你执行快速,基于BLAS的点产品,而无需通过直接调用相关的BLAS函数强制复制scipy.linalg.blas(从这个答案无耻地窃取):

from scipy.linalg import blas
X = np.random.randn(1000,50000)

%memit res1 = np.dot(X, X.T)
# maximum of 1: 845.367188 MB per loop

%memit res2 = blas.dgemm(alpha=1., a=X.T, b=X.T, trans_a=True)
# maximum of 1: 471.656250 MB per loop

print np.all(res1 == res2)
# True
Run Code Online (Sandbox Code Playgroud)