numba中的一组int

DNF*_*DNF 5 python performance types numpy numba

我正在计算int8s 向量中最常用的数字.当我设置一个ints 的计数器数组时,Numba抱怨:

@jit(nopython=True)
def freq_int8(y):
    """Find most frequent number in array"""
    count = np.zeros(256, dtype=int)
    for val in y:
        count[val] += 1
    return ((np.argmax(count)+128) % 256) - 128
Run Code Online (Sandbox Code Playgroud)

调用它我收到以下错误:

TypingError: Invalid usage of Function(<built-in function zeros>) with parameters (int64, Function(<class 'int'>))
Run Code Online (Sandbox Code Playgroud)

如果我删除dtype=int它的工作,我得到一个体面的加速.然而,我很困惑为什么声明一个ints 数组不起作用.有没有已知的解决方法,这里有没有值得拥有的效率增益?

背景:我试图削减一些重量级代码的微秒.我特别受到伤害numpy.median,并一直在调查Numba,但我正在努力改进median.找到最常用的数字是一个可接受的替代方案median,在这里我已经能够获得一些性能.上面的numba代码也快于numpy.bincount.

更新:在接受的答案输入后,这里的实现medianint8载体.它比numpy.median以下快大约一个数量级:

@jit(nopython=True)
def median_int8(y):
    N2 = len(y)//2
    count = np.zeros(256, dtype=np.int32)
    for val in y:
        count[val] += 1
    cs = 0
    for i in range(-128, 128):
        cs += count[i]
        if cs > N2:
            return float(i)
        elif cs == N2:
            j = i+1
            while count[j] == 0:
                j += 1
            return (i + j)/2
Run Code Online (Sandbox Code Playgroud)

令人惊讶的是,短向量的性能差异甚至更大,显然是由于numpy向量的开销:

>>> a = np.random.randint(-128, 128, 10)

>>> %timeit np.median(a)
    The slowest run took 7.03 times longer than the fastest. This could mean that an intermediate result is being cached.
    10000 loops, best of 3: 20.8 µs per loop

>>> %timeit median_int8(a)
    The slowest run took 11.67 times longer than the fastest. This could mean that an intermediate result is being cached.
    1000000 loops, best of 3: 593 ns per loop
Run Code Online (Sandbox Code Playgroud)

这个开销是如此之大,我想知道是否有问题.

Ima*_*ngo 8

简单来说,找到最频繁的数字通常称为模式,它与中位数相似,因为它是平均值 ...在这种情况下np.mean会相当快.除非您的数据存在某些约束或特殊性,否则无法保证模式接近中位数.

如果你仍然想要计算整数列表的模式np.bincount,正如你所提到的,应该足够了(如果numba更快,它应该不会太多):

count = np.bincount(y, minlength=256)
result = ((np.argmax(count)+128) % 256) - 128
Run Code Online (Sandbox Code Playgroud)

注意我已经添加了minlength参数,np.bincount因此它返回了代码中相同的256长度列表.但是,在实践中是完全没有必要的,因为你只想要的argmax,np.bincount(不minlength)会返回一个列表,它的长度是最大数y.

至于numba错误,替换dtype=intdtype=np.int32应该解决问题.int是一个python函数,你nopython在numba头中指定.如果你删除nopython,那么任何一个dtype=intdtype='i'将也会工作(具有相同的效果).