wea*_*guy 5 python performance numpy cython
我试图找到在二维 numpy 数组上获得 numpy 'where' 语句功能的最快方法;即检索满足条件的索引。它比我使用过的其他语言(例如 IDL、Matlab)慢得多。
我已经对一个在嵌套 for 循环中遍历数组的函数进行了cythonized。速度几乎提高了一个数量级,但如果可能的话,我想进一步提高性能。
测试.py:
from cython_where import *
import time
import numpy as np
data = np.zeros((2600,5200))
data[100:200,100:200] = 10
t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0
t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1
Run Code Online (Sandbox Code Playgroud)
我的 cython_where.pyx 程序:
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
assert data.dtype == DTYPE1
cdef int xmax = data.shape[0]
cdef int ymax = data.shape[1]
cdef unsigned int x, y
cdef int count = 0
cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
for x in xrange(xmax):
for y in xrange(ymax):
if(data[x,y] == val):
xind[count] = x
yind[count] = y
count += 1
return tuple([xind[0:count],yind[0:count]]),count
Run Code Online (Sandbox Code Playgroud)
TEST.py 的输出:
cython_test]$ python TEST.py
0.0139019489288
0.0982608795166
我也试过 numpy's argwhere
,它和where
. 我对 numpy 和 cython 还很陌生,所以如果您有任何其他想法可以真正提高性能,我会全力以赴!
贡献:
\n\nNumpy 可以在扁平数组上加速,以获得 4 倍的增益:
\n\n%timeit np.where(data==10)\n1 loops, best of 3: 105 ms per loop\n\n%timeit np.unravel_index(np.where(data.ravel()==10),data.shape)\n10 loops, best of 3: 26.0 ms per loop\n
Run Code Online (Sandbox Code Playgroud)我认为你可以用它来优化你的 cython 代码,避免k=i*ncol+j
对每个单元进行计算。
Numba 给出了一个简单的替代方案:
\n\nfrom numba import jit\ndtype=data.dtype\n@jit(nopython=True)\ndef numbaeq(flatdata,x,nrow,ncol):\n size=ncol*nrow\n ix=np.empty(size,dtype=dtype)\n jx=np.empty(size,dtype=dtype)\n count=0\n k=0\n while k<size:\n if flatdata[k]==x :\n ix[count]=k//ncol\n jx[count]=k%ncol\n count+=1\n k+=1 \n return ix[:count],jx[:count]\n\ndef whereequal(data,x): return numbaeq(data.ravel(),x,*data.shape)\n
Run Code Online (Sandbox Code Playgroud)这使 :
\n\n %timeit whereequal(data,10)\n 10 loops, best of 3: 20.2 ms per loop\n
Run Code Online (Sandbox Code Playgroud)\n\n在 cython 性能下,numba 在此类问题上没有很好的优化。
\n\nk//ncol
并且k%ncol
可以通过优化操作同时计算divmod
。