Rus*_*shi 5 python numpy numpy-einsum
我正在尝试将一些代码从 MATLAB 移植到 Python,但 Python 的性能却慢得多。我不太擅长 Python 编码,因此任何加快这些速度的建议将不胜感激。
我尝试了einsum单行(在我的机器上需要7.5 秒):
import numpy as np
n = 4
N = 200
M = 100
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)
G = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ijk,ljn->ilkn',X,X)), w)
Run Code Online (Sandbox Code Playgroud)
我也尝试了一个matmult实现(在我的机器上需要6 秒)
G = np.zeros((M, M))
for i in range(M):
G[:, i] = np.squeeze(w[i,...].T @ (np.exp(X[i, :, :].T @ X) @ w))
Run Code Online (Sandbox Code Playgroud)
但我原来的 MATLAB 代码要快得多(在我的机器上需要1 秒)
n = 4;
N = 200;
M = 100;
X = 0.1*rand(n, N, M);
w = 0.1*rand(N, 1, M);
G=zeros(M);
for i=1:M
G(:,i) = squeeze(pagemtimes(pagemtimes(w(:,1,i).', exp(pagemtimes(X(:,:,i),'transpose',X,'none'))) ,w));
end
Run Code Online (Sandbox Code Playgroud)
我原以为这两种 Python 实现在速度上具有可比性,但事实并非如此。为什么 Python 实现这么慢,有什么想法吗?或者有什么加快速度的建议吗?
首先有一个默认设置的np.einsum参数optimizeFalse(主要是因为在某些情况下优化可能比计算更昂贵,并且通常最好先在单独的调用中预先计算最佳路径)。您可以使用optimal=True它来显着加速np.einsum(在这种情况下它提供了最佳路径,尽管内部实现不是最佳的)。请注意,pagemtimesMatlab 中的参数更加具体,np.einsum因此不需要这样的参数(即在这种情况下默认情况下它很快)。
此外,Numpy 的功能类似于np.exp默认创建一个新数组。问题是就地计算数组通常更快(而且消耗的内存也更少)。这可以通过参数来完成out。
在大多数机器上它np.exp相当昂贵,因为它串行运行(像大多数 Numpy 函数一样)并且内部通常也不是很优化。使用像英特尔这样的快速数学库会有所帮助。我怀疑 Matlab 在内部使用了这种快速数学库。或者,可以使用多个线程来更快地计算。使用该包很容易做到这一点numexpr。
以下是更优化的 Numpy 代码:
import numpy as np
import numexpr as ne
# [...] Same initialization as in the question
tmp = np.einsum('ijk,ljn->ilkn',X,X, optimize=True)
ne.evaluate('exp(tmp)', out=tmp)
G = np.einsum('ijk,iljm,lmn->il', w, tmp, w, optimize=True)
Run Code Online (Sandbox Code Playgroud)
以下是我的机器上的结果(使用 i5-9600KF CPU、32 GiB RAM、Windows):
Naive einsums: 6.62 s
CPython loops: 3.37 s
This answer: 1.27 s <----
max9111 solution: 0.47 s (using an unmodified Numba v0.57)
max9111 solution: 0.54 s (using a modified Numba v0.57)
Run Code Online (Sandbox Code Playgroud)
优化后的代码比初始代码快约 5.2 倍,比初始最快代码快 2.7 倍!
第一个einsum在我的机器上更快地实现,占用了运行时间的很大一部分。这主要是因为einsum在内部以效率不高的方式执行许多小矩阵乘法。事实上,每个矩阵乘法都是由 BLAS 库并行完成的(例如 OpenBLAS 库,它是像我这样的大多数机器上的默认库)。问题是 OpenBLAS 并行计算小矩阵的效率不高。事实上,并行计算每个小矩阵的效率并不高。更有效的解决方案是并行计算所有矩阵乘法(每个线程应该执行多个串行矩阵乘法)。这当然就是 Matlab 所做的事情,也是为什么它可以更快一点的原因。这可以使用并行 Numba 代码(或使用 Cython)并禁用 BLAS 例程的并行执行来完成(请注意,如果全局完成,这可能会对较大的脚本产生性能副作用)。
另一种可能的优化是使用多个线程在 Numba 中一次性执行所有操作。该解决方案当然可以减少更多的内存占用并进一步提高性能。然而,编写优化的实现绝非易事,而且生成的代码将更加难以维护。这就是 max9111 的代码的作用。
| 归档时间: |
|
| 查看次数: |
287 次 |
| 最近记录: |