kws*_*wsp 6 python arrays performance numpy numba
我正在尝试加速一段将一维数组(过滤器)与二维数组的每一列进行卷积的代码。不知何故,当我用 numba's 运行它时njit,速度减慢了 7 倍。我的想法:
(在 Windows 10、conda 的 python 3.9.4、numpy 1.12.2、numba 0.53.1 上测试)
\n谁能告诉我为什么这段代码很慢?
\nimport numpy as np\nfrom numba import njit\n\ndef f1(a1, filt):\n l2 = filt.size // 2\n res = np.empty(a1.shape)\n for i in range(a1.shape[1]):\n res[:, i] = np.convolve(a1[:, i], filt)[l2:-l2]\n return res\n\n@njit\ndef f1_jit(a1, filt):\n l2 = filt.size // 2\n res = np.empty(a1.shape)\n for i in range(a1.shape[1]):\n res[:, i] = np.convolve(a1[:, i], filt)[l2:-l2]\n return res\n\na1 = np.random.random((6400, 1000))\nfilt = np.random.random((65))\nf1(a1, filt)\nf1_jit(a1, filt)\n\n%timeit f1(a1, filt) # 404 ms \xc2\xb1 19.3 ms per loop (mean \xc2\xb1 std. dev. of 7 runs, 1 loop each)\n%timeit f1_jit(a1, filt) # 2.8 s \xc2\xb1 66.7 ms per loop (mean \xc2\xb1 std. dev. of 7 runs, 1 loop each)\nRun Code Online (Sandbox Code Playgroud)\n
Jér*_*ard 12
问题来自于 的 Numba 实现np.convolve。这是一个已知的问题。事实证明,当前的 Numba 实现比 Numpy 慢得多 (在 Windows 上测试的版本 <=0.54.1)。
一方面,Numpy 实现correlate调用本身执行点积,该点积应该由系统上可用的快速 BLAS 库实现。另一方面,Numba 实现调用_get_inner_prod它np.dot也应该使用相同的 BLAS 库(假设检测到 BLAS 应该是这种情况)...
话虽这么说,有多个与点积相关的问题:
\n首先,如果手动禁用 的内部变量,Numba 使用_HAVE_BLAS点积的后备实现,速度应该会慢得多。然而,事实证明,在我的机器上使用后备点积实现的执行速度比使用 BLAS 包装器快 5 倍!另外使用 Numba 装饰器中的参数可使整体执行速度提高 8.7 倍!这是测试代码:numba/np/arraymath.pynp.convolvefastmath=Truenjit
import numpy as np\nimport numba as nb\n\ndef npConvolve(a, b):\n return np.convolve(a, b)\n\n@nb.njit(\'float64[:](float64[:], float64[:])\')\ndef nbConvolveUncont(a, b):\n return np.convolve(a, b)\n\n@nb.njit(\'float64[::1](float64[::1], float64[::1])\')\ndef nbConvolveCont(a, b):\n return np.convolve(a, b)\n\na = np.random.random(6400)\nb = np.random.random(65)\n%timeit -n 100 npConvolve(a, b)\n%timeit -n 100 nbConvolveUncont(a, b)\n%timeit -n 100 nbConvolveCont(a, b)\nRun Code Online (Sandbox Code Playgroud)\n以下是原始有趣的结果:
\nWith _HAVE_BLAS=True (default):\n126 \xc2\xb5s \xc2\xb1 292 ns per loop\n1.6 ms \xc2\xb1 21.3 \xc2\xb5s per loop\n1.6 ms \xc2\xb1 18.5 \xc2\xb5s per loop\n\nWith _HAVE_BLAS=False:\n125 \xc2\xb5s \xc2\xb1 359 ns per loop\n311 \xc2\xb5s \xc2\xb1 1.18 \xc2\xb5s per loop\n268 \xc2\xb5s \xc2\xb1 4.26 \xc2\xb5s per loop\n\nWith _HAVE_BLAS=False and fastmath=True:\n125 \xc2\xb5s \xc2\xb1 757 ns per loop\n327 \xc2\xb5s \xc2\xb1 3.69 \xc2\xb5s per loop\n183 \xc2\xb5s \xc2\xb1 654 ns per loop\nRun Code Online (Sandbox Code Playgroud)\n此外,np_convolveNumba 内部翻转一些数组参数,然后使用具有不平凡步长(即不是 1)的翻转数组执行点积。这种重要的进步可能会对点积性能产生影响。更一般地说,任何阻止编译器知道数组是连续的转换肯定会严重影响性能。事实上,以下测试显示了使用 Numba 点积实现处理连续数组的影响:
import numpy as np\nimport numba as nb\n\ndef np_dot(a, b):\n return np.dot(a, b)\n\n@nb.njit(\'float64(float64[::1], float64[::1])\')\ndef nb_dot_cont(a, b):\n return np.dot(a, b)\n\n@nb.njit(\'float64(float64[::1], float64[:])\')\ndef nb_dot_stride(a, b):\n return np.dot(a, b)\n\nv = np.random.random(128*1024)\n%timeit -n 200 np_dot(v, v) # 36.5 \xc2\xb5s \xc2\xb1 4.9 \xc2\xb5s per loop\n%timeit -n 200 nb_dot_stride(v, v) # 361.0 \xc2\xb5s \xc2\xb1 17.1 \xc2\xb5s per loop (x10 !!!)\n%timeit -n 200 nb_dot_cont(v, v) # 34.1 \xc2\xb5s \xc2\xb1 2.9 \xc2\xb5s per loop\nRun Code Online (Sandbox Code Playgroud)\n请注意,当 Numba 在相当大的数组上工作时,很难加速 Numpy 调用,因为 Numba主要在 Python 中重新实现 Numpy 函数,并使用JIT 编译器(LLVM-Lite) 来加速它们,而 Numpy 主要以普通格式实现 - C(带有相当慢的 Python 包装代码)。Numpy 代码使用SIMD 指令等低级处理器功能来加快许多函数的执行速度。两者似乎都使用已知高度优化的 BLAS 库。Numpy 通常会更加优化,因为 Numpy 目前比 Numba 更成熟:Numpy 拥有更多的贡献者,而且工作时间更长。
\n| 归档时间: |
|
| 查看次数: |
1787 次 |
| 最近记录: |