Uld*_*tre 6 python numpy matrix
当我计算X具有N行和n列的矩阵的三阶矩时,我通常使用einsum:
M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N
Run Code Online (Sandbox Code Playgroud)
这通常很好,但现在我正在使用更大的值,即n = 120和N = 100000,并einsum返回以下错误:
ValueError:迭代器太大
做3个嵌套循环的替代方案是不可行的,所以我想知道是否有任何替代方案.
请注意,计算这至少需要执行 ~n 3 × N = 1730 亿次操作(不考虑对称性),因此除非 numpy 可以访问 GPU 或其他东西,否则它会很慢。在具有 ~3 GHz CPU 的现代计算机上,假设没有 SIMD/并行加速,预计整个计算需要大约 60 秒才能完成。
对于测试,让我们从 N = 1000 开始。我们将使用它来检查正确性和性能:
#!/usr/bin/env python3
import numpy
import time
numpy.random.seed(0)
n = 120
N = 1000
X = numpy.random.random((N, n))
start_time = time.time()
M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)
end_time = time.time()
print('check:', M3[2,4,6], '= 125.401852515?')
print('check:', M3[4,2,6], '= 125.401852515?')
print('check:', M3[6,4,2], '= 125.401852515?')
print('check:', numpy.sum(M3), '= 218028826.631?')
print('total time =', end_time - start_time)
Run Code Online (Sandbox Code Playgroud)
这大约需要 8 秒。这是基线。
让我们从 3 个嵌套循环作为替代方案开始:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
# ~27 seconds
Run Code Online (Sandbox Code Playgroud)
这大约需要半分钟,不好!一个原因是因为这实际上是四个嵌套循环:numpy.sum也可以认为是一个循环。
我们注意到总和可以变成点积以消除第四个循环:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds
Run Code Online (Sandbox Code Playgroud)
现在好多了,但还是很慢。但是我们注意到点积可以变成矩阵乘法来消除一个循环:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds
Run Code Online (Sandbox Code Playgroud)
嗯?现在这比einsum! 我们还可以检查答案是否确实正确。
我们可以走得更远吗?是的!我们可以k通过以下方式消除循环:
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = numpy.repeat(X[:,j], n).reshape((N, n))
M3[j] = (Y * X).T @ X
# ~0.3 seconds
Run Code Online (Sandbox Code Playgroud)
我们还可以使用广播(即a * [b,c] == [a*b, a*c]X 的每一行)来避免这样做numpy.repeat(感谢 @Divakar):
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = X[:,j].reshape((N, 1))
## or, equivalently:
# Y = X[:, numpy.newaxis, j]
M3[j] = (Y * X).T @ X
# ~0.16 seconds
Run Code Online (Sandbox Code Playgroud)
如果我们将其缩放为 N = 100000,则程序预计需要 16 秒,这在理论限制范围内,因此消除它j可能没有太大帮助(但这可能会使代码真的很难理解)。我们可以接受这是最终解决方案。
注意:如果您使用的是 Python 2,a @ b则相当于a.dot(b).
| 归档时间: |
|
| 查看次数: |
523 次 |
| 最近记录: |