矢量化的基数排序与numpy - 它可以击败np.sort?

dan*_*man 6 python sorting performance numpy vectorization

NumPy的没有尚未有一个基数排序,所以我想知道是否有可能使用一个预先存在numpy的功能来写.到目前为止,我有以下,它确实有效,但比numpy的快速排序慢约10倍.

line profiler输出

测试和基准测试:

a = np.random.randint(0, 1e8, 1e6)
assert(np.all(radix_sort(a) == np.sort(a))) 
%timeit np.sort(a)
%timeit radix_sort(a)
Run Code Online (Sandbox Code Playgroud)

mask_b循环可以至少部分地被矢量化,从掩码中广播&cumsumaxisarg一起使用,但是这最终是一种悲观,可能是由于增加的存储器占用.

如果有人能够看到一种方法来改进我所拥有的东西,我会有兴趣听到,即使它仍然比np.sort... 慢......这更像是一种对知识的好奇心和对numpy技巧的兴趣.

请注意,您可以轻松地实现快速计数排序,但这仅与小整数数据相关.

编辑1:np.arange(n)圈外的帮助一点,但不是很exiciting.

编辑2:cumsum实际上是多余的(哎呀!),但这个简单的版本仅具有性能稍微帮助..

def radix_sort(a):
    bit_len = np.max(a).bit_length()
    n = len(a)
    cached_arange = arange(n)
    idx = np.empty(n, dtype=int) # fully overwritten each iteration
    for mask_b in xrange(bit_len):
        is_one = (a & 2**mask_b).astype(bool)
        n_ones = np.sum(is_one)      
        n_zeros = n-n_ones
        idx[~is_one] = cached_arange[:n_zeros]
        idx[is_one] = cached_arange[:n_ones] + n_zeros
        # next three lines just do: a[idx] = a, but correctly
        new_a = np.empty(n, dtype=a.dtype)
        new_a[idx] = a
        a = new_a
    return a
Run Code Online (Sandbox Code Playgroud)

编辑3:如果您在多个步骤中构造idx,则可以一次循环两个或更多个,而不是循环使用单个位.使用2位有点帮助,我没有尝试过更多:

idx[is_zero] = np.arange(n_zeros)
idx[is_one] = np.arange(n_ones)
idx[is_two] = np.arange(n_twos)
idx[is_three] = np.arange(n_threes)
Run Code Online (Sandbox Code Playgroud)

编辑4和5:对于我正在测试的输入,4位似乎是最好的.此外,你可以idx完全摆脱这一步.现在只有5倍,而不是10倍,慢于np.sort(作为gist提供的源代码):

在此输入图像描述

编辑6:这是上面的一个整理版本,但它也有点.80%的时间花在repeatextract- 如果只有一种方式广播extract:( ...

def radix_sort(a, batch_m_bits=3):
    bit_len = np.max(a).bit_length()
    batch_m = 2**batch_m_bits
    mask = 2**batch_m_bits - 1
    val_set = np.arange(batch_m, dtype=a.dtype)[:, nax] # nax = np.newaxis
    for _ in range((bit_len-1)//batch_m_bits + 1): # ceil-division
        a = np.extract((a & mask)[nax, :] == val_set,
                        np.repeat(a[nax, :], batch_m, axis=0))
        val_set <<= batch_m_bits
        mask <<= batch_m_bits
    return a
Run Code Online (Sandbox Code Playgroud)

编辑7和8:实际上,您可以使用as_stridedfrom 来广播提取numpy.lib.stride_tricks,但它似乎没有太大的性能帮助:

在此输入图像描述

最初这对我来说是有意义的,因为它extract会在整个数组batch_m时间内进行迭代,因此CPU请求的高速缓存行总数将与之前相同(只是在它请求每个请求的过程结束时)缓存行batch_m时间).然而,实际情况是,extract不足以巧妙地迭代任意阶梯数组,并且必须在开始之前扩展数组,即无论如何最终都会重复执行.事实上,在查看源代码之后extract,我现在看到我们用这种方法做的最好的事情是:

a = a[np.flatnonzero((a & mask)[nax, :] == val_set) % len(a)]
Run Code Online (Sandbox Code Playgroud)

这比一点慢extract.然而,如果len(a)是两个电源可以代替昂贵的MOD与操作& (len(a) - 1),这并最终被略高于更快的extract版本(目前约4.9x np.sorta=randint(0, 1e8, 2**20).我想我们可以通过零填充使这两个长度的非幂次工作,然后在排序结束时裁剪额外的零...但是这将是一个悲观,除非长度已经接近于二.

rcg*_*ldr 0

您能否将其更改为一次运行 8 位的计数/基数排序?对于 32 位无符号整数,创建一个字节字段出现次数的矩阵 [4][257],使一次读取传递到要排序的数组。矩阵[][0] = 0,矩阵[][1] = 0、...出现的次数。然后将计数转换为索引,其中matrix[][0] = 0,matrix[][1] = # of bytes == 0,matrix[][2] = # of bytes == 0 + # of bytes == 1、.... 不使用最后一个计数,因为这将索引数组的末尾。然后进行 4 次基数排序,在原始数组和输出数组之间来回移动数据。每次工作 16 位需要一个矩阵[2][65537],但只需要 2 遍。C 代码示例:

size_t mIndex[4][257] = {0};            /* index matrix */
size_t i, j, m;
uint32_t u;
uint32_t *pData;                        /* ptr to original array */
uint32_t *pTemp;                        /* ptr to working array */
uint32_t *pSrc;                         /* working ptr */
uint32_t *pDst;                         /* working ptr */
/* n is size of array */
    for(i = 0; i < n; i++){             /* generate histograms */
        u = pData[i];
        for(j = 0; j < 4; j++){
            mIndex[j][1 + (size_t)(u & 0xff)]++; /* note [1 + ... */
            u >>= 8;
        }       
    }
    for(j = 0; j < 4; j++){             /* convert to indices */
        for(i = 1; i < 257; i++){       /* (last count never used) */
            mIndex[j][i] += mIndex[j][i-1]
        }       
    }
    pDst = pTemp;                       /* radix sort */
    pSrc = pData;
    for(j = 0; j < 4; j++){
        for(i = 0; i < count; i++){     /* sort pass */
            u = pSrc[i];
            m = (size_t)(u >> (j<<3)) & 0xff;
        /*  pDst[mIndex[j][m]++] = u;      split into 2 lines */
            pDst[mIndex[j][m]] = u;
            mIndex[j][m]++;
        }
        pTmp = pSrc;                    /* swap ptrs */
        pSrc = pDst;
        pDst = pTmp;
    }
Run Code Online (Sandbox Code Playgroud)