使用itertools.groupby性能进行Numpy分组

Don*_*nny 26 python algorithm numpy

我有许多大型(> 35,000,000)整数列表,它们将包含重复项.我需要计算列表中每个整数的计数.以下代码有效,但似乎很慢.任何人都可以使用Python更好的基准测试,最好是Numpy吗?

def group():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    groups = ((k,len(list(g))) for k,g in groupby(values))
    index = np.fromiter(groups,dtype='u4,u2')

if __name__=='__main__':
    from timeit import Timer
    t = Timer("group()","from __main__ import group")
    print t.timeit(number=1)
Run Code Online (Sandbox Code Playgroud)

返回:

$ python bench.py 
111.377498865
Run Code Online (Sandbox Code Playgroud)

干杯!

根据回复进行编辑:

def group_original():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    groups = ((k,len(list(g))) for k,g in groupby(values))
    index = np.fromiter(groups,dtype='u4,u2')

def group_gnibbler():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    groups = ((k,sum(1 for i in g)) for k,g in groupby(values))
    index = np.fromiter(groups,dtype='u4,u2')

def group_christophe():
    import numpy as np
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    counts=values.searchsorted(values, side='right') - values.searchsorted(values, side='left')
    index = np.zeros(len(values),dtype='u4,u2')
    index['f0']=values
    index['f1']=counts
    #Erroneous result!

def group_paul():
    import numpy as np
    values = np.array(np.random.randint(0,1<<32,size=35000000),dtype='u4')
    values.sort()
    diff = np.concatenate(([1],np.diff(values)))
    idx = np.concatenate((np.where(diff)[0],[len(values)]))
    index = np.empty(len(idx)-1,dtype='u4,u2')
    index['f0']=values[idx[:-1]]
    index['f1']=np.diff(idx)

if __name__=='__main__':
    from timeit import Timer
    timings=[
                ("group_original","Original"),
                ("group_gnibbler","Gnibbler"),
                ("group_christophe","Christophe"),
                ("group_paul","Paul"),
            ]
    for method,title in timings:
        t = Timer("%s()"%method,"from __main__ import %s"%method)
        print "%s: %s secs"%(title,t.timeit(number=1))
Run Code Online (Sandbox Code Playgroud)

返回:

$ python bench.py 
Original: 113.385262966 secs
Gnibbler: 71.7464978695 secs
Christophe: 27.1690568924 secs
Paul: 9.06268405914 secs
Run Code Online (Sandbox Code Playgroud)

尽管Christophe目前给出的结果不正确

Pau*_*aul 31

做这样的事情我得到了3倍的改善:

def group():
    import numpy as np
    values = np.array(np.random.randint(0,3298,size=35000000),dtype='u4')
    values.sort()
    dif = np.ones(values.shape,values.dtype)
    dif[1:] = np.diff(values)
    idx = np.where(dif>0)
    vals = values[idx]
    count = np.diff(idx)
Run Code Online (Sandbox Code Playgroud)

  • 可以通过使用行`idx = np.concatenate((np.where(dif)[0],[len(values)]))和`vals = values [idx [: - 1]]来改变错误. `分别代替第3行和第2行.这真的是使用numpy的最佳答案.如果你想要它更快,我建议使用Cython.这很容易做到,并且在这个numpy代码上显着提高了速度和内存. (2认同)

Ali*_*Ali 11

保罗的回答被接受已经过去了5年多.有趣的sort()是,它仍然是公认解决方案的瓶颈.

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           @profile
     4                                           def group_paul():
     5         1        99040  99040.0      2.4      import numpy as np
     6         1       305651 305651.0      7.4      values = np.array(np.random.randint(0, 2**32,size=35000000),dtype='u4')
     7         1      2928204 2928204.0    71.3      values.sort()
     8         1        78268  78268.0      1.9      diff = np.concatenate(([1],np.diff(values)))
     9         1       215774 215774.0      5.3      idx = np.concatenate((np.where(diff)[0],[len(values)]))
    10         1           95     95.0      0.0      index = np.empty(len(idx)-1,dtype='u4,u2')
    11         1       386673 386673.0      9.4      index['f0'] = values[idx[:-1]]
    12         1        91492  91492.0      2.2      index['f1'] = np.diff(idx)
Run Code Online (Sandbox Code Playgroud)

接受的解决方案在我的机器上运行4.0秒,基数排序下降到1.7秒.

只需切换到基数排序,我就可以获得2.35倍的加速.在这种情况下,基数排序比快速排序快4倍.

请参阅如何比quicksort更快地对整数数组进行排序?这是出于你的问题.

  • @danijar 是的,这很酷。:) 它是 [line_profiler 和 kernprof](https://github.com/rkern/line_profiler)。 (2认同)

Jus*_*eel 5

根据要求,这是一个 Cython 版本。我做了两次遍历数组。第一个找出有多少唯一元素,以便我的数组可以获取适当大小的唯一值和计数。

import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
def dogroup():
    cdef unsigned long tot = 1
    cdef np.ndarray[np.uint32_t, ndim=1] values = np.array(np.random.randint(35000000,size=35000000),dtype=np.uint32)
    cdef unsigned long i, ind, lastval
    values.sort()
    for i in xrange(1,len(values)):
        if values[i] != values[i-1]:
            tot += 1
    cdef np.ndarray[np.uint32_t, ndim=1] vals = np.empty(tot,dtype=np.uint32)
    cdef np.ndarray[np.uint32_t, ndim=1] count = np.empty(tot,dtype=np.uint32)
    vals[0] = values[0]
    ind = 1
    lastval = 0
    for i in xrange(1,len(values)):
        if values[i] != values[i-1]:
            vals[ind] = values[i]
            count[ind-1] = i - lastval
            lastval = i
            ind += 1
    count[ind-1] = len(values) - lastval
Run Code Online (Sandbox Code Playgroud)

到目前为止,排序实际上在这里花费的时间最多。使用我的代码中给出的值数组,排序需要 4.75 秒,而实际查找唯一值和计数需要 0.67 秒。使用使用 Paul 代码的纯 Numpy 代码(但具有相同形式的值数组)以及我在评论中建议的修复,查找唯一值和计数需要 1.9 秒(当然排序仍然需要相同的时间)。

排序占用大部分时间是有意义的,因为它是 O(N log N) 并且计数是 O(N)。您可以比 Numpy 稍微加快排序(如果我没记错的话,它使用 C 的 qsort),但是您必须真正知道自己在做什么,这可能不值得。此外,可能有一些方法可以稍微加快我的 Cython 代码的速度,但这可能不值得。