wea*_*guy 5 python performance numpy cython
我试图找到在二维 numpy 数组上获得 numpy 'where' 语句功能的最快方法;即检索满足条件的索引。它比我使用过的其他语言(例如 IDL、Matlab)慢得多。
我已经对一个在嵌套 for 循环中遍历数组的函数进行了cythonized。速度几乎提高了一个数量级,但如果可能的话,我想进一步提高性能。
测试.py:
from cython_where import *
import time
import numpy as np
data = np.zeros((2600,5200))
data[100:200,100:200] = 10
t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0
t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1
我的 cython_where.pyx 程序:
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
  assert data.dtype == DTYPE1
  cdef int xmax = data.shape[0]
  cdef int ymax = data.shape[1]
  cdef unsigned int x, y
  cdef int count = 0
  cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
  cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
  if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
    for x in xrange(xmax):
    for y in xrange(ymax):
      if(data[x,y] == val):
        xind[count] = x
        yind[count] = y
        count += 1
 return tuple([xind[0:count],yind[0:count]]),count
TEST.py 的输出:
cython_test]$ python TEST.py
0.0139019489288
0.0982608795166
我也试过 numpy's argwhere,它和where. 我对 numpy 和 cython 还很陌生,所以如果您有任何其他想法可以真正提高性能,我会全力以赴!
贡献:
\n\nNumpy 可以在扁平数组上加速,以获得 4 倍的增益:
\n\n%timeit np.where(data==10)\n1 loops, best of 3: 105 ms per loop\n\n%timeit np.unravel_index(np.where(data.ravel()==10),data.shape)\n10 loops, best of 3: 26.0 ms per loop\n我认为你可以用它来优化你的 cython 代码,避免k=i*ncol+j对每个单元进行计算。
Numba 给出了一个简单的替代方案:
\n\nfrom numba import jit\ndtype=data.dtype\n@jit(nopython=True)\ndef numbaeq(flatdata,x,nrow,ncol):\n  size=ncol*nrow\n  ix=np.empty(size,dtype=dtype)\n  jx=np.empty(size,dtype=dtype)\n  count=0\n  k=0\n  while k<size:\n    if flatdata[k]==x :\n      ix[count]=k//ncol\n      jx[count]=k%ncol\n      count+=1\n    k+=1          \n  return ix[:count],jx[:count]\n\ndef whereequal(data,x): return numbaeq(data.ravel(),x,*data.shape)\n这使 :
\n\n    %timeit whereequal(data,10)\n    10 loops, best of 3: 20.2 ms per loop\n在 cython 性能下,numba 在此类问题上没有很好的优化。
\n\nk//ncol并且k%ncol可以通过优化操作同时计算divmod。