Ste*_* Lu 2 python numpy numba
我正在使用 Numba 非 python 模式和一些 NumPy 函数。
@njit
def invert(W, copy=True):
'''
Inverts elementwise the weights in an input connection matrix.
In other words, change the from the matrix of internode strengths to the
matrix of internode distances.
If copy is not set, this function will *modify W in place.*
Parameters
----------
W : np.ndarray
weighted connectivity matrix
copy : bool
Returns
-------
W : np.ndarray
inverted connectivity matrix
'''
if copy:
W = W.copy()
E = np.where(W)
W[E] = 1. / W[E]
return W
Run Code Online (Sandbox Code Playgroud)
在这个函数中,W是一个矩阵。但我收到以下错误。可能和W[E] = 1. / W[E]线路有关。
File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, A), tuple(array(int64, 1d, C) x 2))
Run Code Online (Sandbox Code Playgroud)
那么 NumPy 和 Numba 的正确使用方法是什么?我知道 NumPy 在矩阵计算方面做得很好。在这种情况下,NumPy 是否足够快以至于 Numba 不再提供加速?
正如 FBruzzesi 在评论中提到的,代码无法编译的原因是您使用了“花式索引”,因为 inE是W[E]的输出np.where,并且是数组的元组。(这解释了有点神秘的错误消息:Numba 不知道如何使用getitem,即当输入之一是元组时,它不知道如何在括号中查找某些内容。)
Numba实际上支持单个维度上的高级索引(也称为“高级索引”),但不支持多个维度。在您的情况下,这允许进行简单的修改:首先使用ravel几乎无成本地将数组变为一维,然后应用转换,然后进行廉价的reshape返回。
@njit\ndef invert2(W, copy=True):\n if copy:\n W = W.copy()\n Z = W.ravel()\n E = np.where(Z)\n Z[E] = 1. / Z[E]\n return Z.reshape(W.shape)\nRun Code Online (Sandbox Code Playgroud)\n\n但这仍然比需要的慢,因为计算通过不必要的中间数组传递,而不是在遇到非零值时立即修改数组。简单地执行循环会更快:
\n\n@njit \ndef invert3(W, copy=True): \n if copy: \n W = W.copy() \n Z = W.ravel() \n for i in range(len(Z)): \n if Z[i] != 0: \n Z[i] = 1/Z[i] \n return Z.reshape(W.shape) \nRun Code Online (Sandbox Code Playgroud)\n\n无论 的尺寸如何,此代码都可以工作W。如果我们知道它W是二维的,那么我们可以直接迭代这两个维度,但由于两者具有相似的性能,我将采用更通用的路线。
在我的计算机上,假设有一个 300×300 的数组W,其中大约一半的条目是 0,并且invert未经 Numba 编译的原始函数在哪里,时间如下:
In [80]: %timeit invert(W) \n2.67 ms \xc2\xb1 49.3 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 100 loops each)\n\nIn [81]: %timeit invert2(W) \n519 \xc2\xb5s \xc2\xb1 24.5 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1000 loops each)\n\nIn [82]: %timeit invert3(W) \n186 \xc2\xb5s \xc2\xb1 11.1 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 10000 loops each)\nRun Code Online (Sandbox Code Playgroud)\n\n因此,Numba 为我们提供了相当大的加速(在它已经运行一次以消除编译时间之后),特别是在以 Numba 可以利用的高效循环风格重写代码之后。
\n