在numpy数组中查找条件索引的最快方法

wea*_*guy 5 python performance numpy cython

我试图找到在二维 numpy 数组上获得 numpy 'where' 语句功能的最快方法;即检索满足条件的索引。它比我使用过的其他语言(例如 IDL、Matlab)慢得多。

我已经一个在嵌套 for 循环中遍历数组的函数进行了cythonized。速度几乎提高了一个数量级,但如果可能的话,我想进一步提高性能。

测试.py:

from cython_where import *
import time
import numpy as np

data = np.zeros((2600,5200))
data[100:200,100:200] = 10

t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0

t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1
Run Code Online (Sandbox Code Playgroud)

我的 cython_where.pyx 程序:

from __future__ import division
import numpy as np
cimport numpy as np
cimport cython

DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)

def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
  assert data.dtype == DTYPE1

  cdef int xmax = data.shape[0]
  cdef int ymax = data.shape[1]
  cdef unsigned int x, y
  cdef int count = 0
  cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
  cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
  if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
    for x in xrange(xmax):
    for y in xrange(ymax):
      if(data[x,y] == val):
        xind[count] = x
        yind[count] = y
        count += 1

 return tuple([xind[0:count],yind[0:count]]),count
Run Code Online (Sandbox Code Playgroud)

TEST.py 的输出: cython_test]$ python TEST.py 0.0139019489288 0.0982608795166

我也试过 numpy's argwhere,它和where. 我对 numpy 和 cython 还很陌生,所以如果您有任何其他想法可以真正提高性能,我会全力以赴!

B. *_* M. 3

贡献:

\n\n
    \n
  • Numpy 可以在扁平数组上加速,以获得 4 倍的增益:

    \n\n
    %timeit np.where(data==10)\n1 loops, best of 3: 105 ms per loop\n\n%timeit np.unravel_index(np.where(data.ravel()==10),data.shape)\n10 loops, best of 3: 26.0 ms per loop\n
    Run Code Online (Sandbox Code Playgroud)
  • \n
\n\n

我认为你可以用它来优化你的 cython 代码,避免k=i*ncol+j对每个单元进行计算。

\n\n
    \n
  • Numba 给出了一个简单的替代方案:

    \n\n
    from numba import jit\ndtype=data.dtype\n@jit(nopython=True)\ndef numbaeq(flatdata,x,nrow,ncol):\n  size=ncol*nrow\n  ix=np.empty(size,dtype=dtype)\n  jx=np.empty(size,dtype=dtype)\n  count=0\n  k=0\n  while k<size:\n    if flatdata[k]==x :\n      ix[count]=k//ncol\n      jx[count]=k%ncol\n      count+=1\n    k+=1          \n  return ix[:count],jx[:count]\n\ndef whereequal(data,x): return numbaeq(data.ravel(),x,*data.shape)\n
    Run Code Online (Sandbox Code Playgroud)
  • \n
\n\n

这使 :

\n\n
    %timeit whereequal(data,10)\n    10 loops, best of 3: 20.2 ms per loop\n
Run Code Online (Sandbox Code Playgroud)\n\n

在 cython 性能下,numba 在此类问题上没有很好的优化。

\n\n
    \n
  • k//ncol并且k%ncol可以通过优化操作同时计算divmod
  • \n
  • 最终步骤是汇编语言和 parall\xc3\xa9lizes ,但它是其他运动。
  • \n
\n