Fastest way to count array values above a threshold in numpy

use*_*911 6 python arrays performance boolean numpy

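I have a numpy array containing 10^8 floats and want to count how many of them are >= a given threshold. Speed is crucial because the operation has to be done on a large number of such arrays. The contenders so far are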

np.sum(myarray >= thresh)

np.size(np.where(np.reshape(myarray,-1) >= thresh))

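The answers to "Count all values in a matrix greater than a value" suggest that np.where() would be faster, but I have found inconsistent timing results. What I mean is that for some realizations and Boolean conditions np.size(np.where(cond)) is faster than np.sum(cond), but for others it is actually slower.

Specifically, if a large fraction of the entries fulfil the condition, then np.sum(cond) is significantly faster, but if only a small fraction (maybe less than a tenth) do, then np.size(np.where(cond)) wins.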

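The question breaks down into two parts:

  • Any other suggestions?
  • Does it make sense that the time np.size(np.where(cond)) takes increases with the number of entries for which cond is true?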
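To make the second point concrete, here is a minimal sketch of what each expression produces. The helper name count_three_ways and the reduced array size are assumptions made for illustration only. np.where(cond) returns index arrays with one entry per true element, so its output grows with the number of hits, while np.sum and np.count_nonzero reduce the boolean mask to a single number. This shows the difference in output size and does not by itself establish which is faster.

```python
import numpy as np


def count_three_ways(arr, thresh):
    """Hypothetical helper: count entries >= thresh with three expressions."""
    cond = np.reshape(arr, -1) >= thresh      # flat boolean mask
    n_sum = int(np.sum(cond))                 # adds up the True values
    n_nonzero = int(np.count_nonzero(cond))   # counts the True values directly
    # For a 1-D mask, np.where returns a tuple holding one index array whose
    # length equals the number of True entries. (For a 2-D mask it would hold
    # two such arrays, which is why the mask is flattened first.)
    n_where = int(np.size(np.where(cond)))
    return n_sum, n_nonzero, n_where


if __name__ == "__main__":
    rng = np.random.default_rng(0)            # seeded only for repeatability
    myarray = rng.random(1_000_000)           # assumed size, smaller than 10**8
    for thresh in (0.1, 0.5, 0.9):
        idx = np.where(myarray >= thresh)[0]
        print(thresh, count_three_ways(myarray, thresh), idx.nbytes, "bytes of indices")
```

The printed byte counts shrink as thresh rises, because fewer indices have to be written out.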

M4r*_*ini 3

Using cython might be a decent alternative.

import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange


DTYPE_f64 = np.float64
ctypedef np.float64_t DTYPE_f64_t


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef int count_above_cython(DTYPE_f64_t [:] arr_view, DTYPE_f64_t thresh) nogil:

    cdef int length, i, total
    total = 0
    length = arr_view.shape[0]

    # prange only runs in parallel if the extension is built with OpenMP
    # enabled (e.g. -fopenmp with GCC); otherwise it is a plain serial loop.
    # The in-place `total += 1` is treated as a reduction across threads.
    for i in prange(length):
        if arr_view[i] >= thresh:
            total += 1

    return total


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def count_above(np.ndarray arr, DTYPE_f64_t thresh):

    # Flatten to 1-D so arrays of any shape can be counted; expects float64.
    cdef DTYPE_f64_t [:] arr_view = arr.ravel()
    cdef int total

    with nogil:
       total =  count_above_cython(arr_view, thresh)
    return total

Timings of the different proposed methods.

myarr = np.random.random((1000,1000))
thresh = 0.33

In [6]: %timeit count_above(myarr, thresh)
1000 loops, best of 3: 693 µs per loop

In [9]: %timeit np.count_nonzero(myarr >= thresh)
100 loops, best of 3: 4.45 ms per loop

In [11]: %timeit np.sum(myarr >= thresh)
100 loops, best of 3: 4.86 ms per loop

In [12]: %timeit np.size(np.where(np.reshape(myarr,-1) >= thresh))
10 loops, best of 3: 61.6 ms per loop

For a larger array:

In [13]: myarr = np.random.random(10**8)

In [14]: %timeit count_above(myarr, thresh)
10 loops, best of 3: 63.4 ms per loop

In [15]: %timeit np.count_nonzero(myarr >= thresh)
1 loops, best of 3: 473 ms per loop

In [16]: %timeit np.sum(myarr >= thresh)
1 loops, best of 3: 511 ms per loop

In [17]: %timeit np.size(np.where(np.reshape(myarr,-1) >= thresh))
1 loops, best of 3: 6.07 s per loop
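For anyone who wants to repeat the comparison outside IPython, the following is a rough sketch using the standard timeit module. The helper names best_time and compare are made up for this example, and the Cython count_above is left out because it has to be compiled first. Absolute numbers will depend on the machine and the NumPy version, so treat the output as indicative only.

```python
import timeit

import numpy as np


def best_time(func, repeat=3, number=3):
    """Best per-call time in seconds over `repeat` runs of `number` calls."""
    return min(timeit.repeat(func, repeat=repeat, number=number)) / number


def compare(arr, thresh):
    """Hypothetical helper: time the three pure-NumPy counting expressions."""
    candidates = {
        "count_nonzero": lambda: np.count_nonzero(arr >= thresh),
        "sum": lambda: np.sum(arr >= thresh),
        "size(where)": lambda: np.size(np.where(np.reshape(arr, -1) >= thresh)),
    }
    # Sanity check: all three expressions must agree on the count.
    counts = {int(f()) for f in candidates.values()}
    assert len(counts) == 1
    return {name: best_time(f) for name, f in candidates.items()}


if __name__ == "__main__":
    myarr = np.random.random((1000, 1000))
    # Vary the threshold to change the fraction of entries that pass.
    for thresh in (0.05, 0.33, 0.95):
        for name, seconds in compare(myarr, thresh).items():
            print(f"thresh={thresh:.2f}  {name:<14} {seconds * 1e3:.3f} ms")
```

Varying thresh changes the fraction of entries that satisfy the condition, which is the effect described in the question.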
Run Code Online (Sandbox Code Playgroud)\n