非常频繁地调用std :: nth_element()函数

kar*_*ten 15 c++ sorting performance inline nth-element

我没有在任何地方找到这个特定主题......

我在23个整数的std :: vector中的不同数据上调用nth_element()算法,每秒大约400,000次,更精确的"无符号短"值.

我想提高计算速度,这个特定的调用需要很大一部分CPU时间.现在我注意到,与std :: sort()一样,即使具有最高优化级别和NDEBUG模式(Linux Clang编译器),nth_element函数在探查器中也是可见的,因此比较是内联的而不是函数调用本身.好吧,更多的preise:不是nth_element()但是std :: __ introselect()是可见的.

由于数据的大小很小,我尝试使用二次排序函数PIKSORT,当数据大小小于20个元素时,它通常比调用std :: sort更快,可能是因为函数将是内联的.

template <class CONTAINER>
inline void piksort(CONTAINER& arr)  // indeed this is "insertion sort"
{
    typename CONTAINER::value_type a;

    const int n = (int)arr.size();
    for (int j = 1; j<n; ++j) {
        a = arr[j];
        int i = j;
        while (i > 0 && a < arr[i - 1]) {
            arr[i] = arr[i - 1];
            i--;
        }
        arr[i] = a;
    }
}
Run Code Online (Sandbox Code Playgroud)

然而,这比在这种情况下使用nth_element慢.

此外,使用统计方法是不合适的,比std :: nth_element更快

最后,由于值在0到约20000的范围内,因此直方图方法看起来不合适.

我的问题:有没有人知道一个简单的解决方案?我想我可能不是唯一一个经常调用std :: sort或nth_element的人.

Mor*_*enn 13

您提到数组的大小始终是已知的23.而且,使用的类型是unsigned short.在这种情况下,您可能会尝试使用大小的排序网络23; 因为您的类型是unsigned short,使用排序网络对整个数组进行排序可能比部分排序更快std::nth_element.下面是一个非常直接的C++ 14实现的大小排序网络,23具有118比较交换单元,如使用对称和进化搜索最小化排序网络所述:

template<typename RandomIt, typename Compare = std::less<>>
void network_sort23(RandomIt first, Compare compare={})
{
    swap_if(first[1u], first[20u], compare);
    swap_if(first[2u], first[21u], compare);
    swap_if(first[5u], first[13u], compare);
    swap_if(first[9u], first[17u], compare);
    swap_if(first[0u], first[7u], compare);
    swap_if(first[15u], first[22u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[12u], compare);
    swap_if(first[10u], first[16u], compare);
    swap_if(first[8u], first[18u], compare);
    swap_if(first[14u], first[19u], compare);
    swap_if(first[3u], first[8u], compare);
    swap_if(first[4u], first[14u], compare);
    swap_if(first[11u], first[18u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[0u], first[9u], compare);
    swap_if(first[13u], first[22u], compare);
    swap_if(first[5u], first[15u], compare);
    swap_if(first[7u], first[17u], compare);
    swap_if(first[1u], first[10u], compare);
    swap_if(first[12u], first[21u], compare);
    swap_if(first[8u], first[19u], compare);
    swap_if(first[17u], first[22u], compare);
    swap_if(first[0u], first[5u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[21u], first[22u], compare);
    swap_if(first[0u], first[1u], compare);
    swap_if(first[19u], first[22u], compare);
    swap_if(first[0u], first[3u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[6u], first[15u], compare);
    swap_if(first[7u], first[16u], compare);
    swap_if(first[8u], first[11u], compare);
    swap_if(first[11u], first[14u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[8u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[17u], first[20u], compare);
    swap_if(first[2u], first[5u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[10u], first[13u], compare);
    swap_if(first[15u], first[18u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[4u], first[7u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[7u], first[15u], compare);
    swap_if(first[3u], first[9u], compare);
    swap_if(first[13u], first[19u], compare);
    swap_if(first[16u], first[18u], compare);
    swap_if(first[8u], first[14u], compare);
    swap_if(first[4u], first[6u], compare);
    swap_if(first[18u], first[21u], compare);
    swap_if(first[1u], first[4u], compare);
    swap_if(first[19u], first[21u], compare);
    swap_if(first[1u], first[3u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[4u], first[9u], compare);
    swap_if(first[13u], first[18u], compare);
    swap_if(first[19u], first[20u], compare);
    swap_if(first[2u], first[3u], compare);
    swap_if(first[18u], first[20u], compare);
    swap_if(first[2u], first[4u], compare);
    swap_if(first[5u], first[17u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[8u], compare);
    swap_if(first[14u], first[17u], compare);
    swap_if(first[3u], first[5u], compare);
    swap_if(first[17u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[6u], first[10u], compare);
    swap_if(first[11u], first[16u], compare);
    swap_if(first[13u], first[16u], compare);
    swap_if(first[6u], first[9u], compare);
    swap_if(first[16u], first[17u], compare);
    swap_if(first[5u], first[6u], compare);
    swap_if(first[4u], first[5u], compare);
    swap_if(first[7u], first[9u], compare);
    swap_if(first[17u], first[18u], compare);
    swap_if(first[12u], first[15u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[13u], first[15u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[10u], first[14u], compare);
    swap_if(first[6u], first[11u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[15u], first[16u], compare);
    swap_if(first[6u], first[7u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[13u], first[14u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[11u], first[12u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[11u], first[12u], compare);
}
Run Code Online (Sandbox Code Playgroud)

swap_if效用函数比较两个参数xy与谓词compare和交换他们,如果compare(y, x).我的示例使用了一个泛型swap_if函数,但如果您知道将要比较unsigned short值,则可以使用优化版本operator<(如果您的编译器识别并优化了比较交换,您可能不需要这样的函数,但不幸的是,并非所有编译器都是如此)这样做 - 我正在使用g ++ 5.2,-O3我仍然需要以下函数来提高性能):

void swap_if(unsigned short& x, unsigned short& y)
{
    unsigned short dx = x;
    unsigned short dy = y;
    unsigned short tmp = x = std::min(dx, dy);
    y ^= dx ^ tmp;
}
Run Code Online (Sandbox Code Playgroud)

现在,为了确保它确实更快,我决定std::nth_element在需要时仅对前10个元素进行部分排序,然后使用排序网络对整个23个元素进行排序(使用不同的混洗数组进行1000000次).这是我得到的:

std::nth_element    1158ms
network_sort23      487ms
Run Code Online (Sandbox Code Playgroud)

也就是说,我的电脑已经运行了一段时间而且速度有点慢,但性能差异很大.我相信当我重新启动计算机时,这种差异将保持不变.我可以稍后再试,让你知道.

关于如何生成这些时间,我使用了来自cpp-sort库此基准测试的修改版本.原始的分拣网络和功能也来自那里,所以你可以确定它们已被多次测试:)swap_if

编辑:这是我已经重新启动计算机的结果.该network_sort23版本仍然比std::nth_element以下快两倍:

std::nth_element    369ms
network_sort23      154ms
Run Code Online (Sandbox Code Playgroud)

EDIT²:如果你需要中位数,你可以简单地删除不需要的比较交换单元来计算将在第11位的最终值.随后得到的大小为23的中位数查找网络使用与前一个不同的23号大小的排序网络,它产生稍微好一点的结果:

swap_if(first[0u], first[1u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[11u], compare);
swap_if(first[0u], first[2u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[8u], first[10u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[5u], first[9u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[3u], first[7u], compare);
swap_if(first[4u], first[8u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[7u], first[10u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[20u], first[22u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[17u], first[21u], compare);
swap_if(first[14u], first[18u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[15u], first[19u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[15u], first[20u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[0u], first[12u], compare);
swap_if(first[2u], first[14u], compare);
swap_if(first[4u], first[16u], compare);
swap_if(first[6u], first[18u], compare);
swap_if(first[8u], first[20u], compare);
swap_if(first[10u], first[22u], compare);
swap_if(first[2u], first[12u], compare);
swap_if(first[10u], first[20u], compare);
swap_if(first[4u], first[12u], compare);
swap_if(first[6u], first[14u], compare);
swap_if(first[8u], first[16u], compare);
swap_if(first[10u], first[18u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[10u], first[12u], compare);
swap_if(first[1u], first[13u], compare);
swap_if(first[3u], first[15u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[7u], first[19u], compare);
swap_if(first[9u], first[21u], compare);
swap_if(first[3u], first[13u], compare);
swap_if(first[11u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[11u], first[19u], compare);
swap_if(first[9u], first[13u], compare);
swap_if(first[11u], first[15u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[11u], first[12u], compare);
Run Code Online (Sandbox Code Playgroud)

可能有更聪明的方法来产生中位数发现网络,但我认为没有对该主题进行过广泛的研究.因此,它可能是您现在可以使用的最佳方法.结果并不令人敬畏,但它仍然使用104个比较交换单元而不是118个.

  • 这是一种问题和答案,通过一堆"做我的家庭作业"或"调试我的程序"的帖子进行疏浚,这些帖子今天吸引了大量有价值的东西. (3认同)

stg*_*lov 5

大概的概念

查看std::nth_elementMSVC2013中的源代码,似乎N <= 32的情况是通过插入排序解决的。这意味着 STL 实现者意识到,尽管该大小的渐近性更好,但进​​行随机分区会更慢。

提高性能的方法之一是优化排序算法。@Morwenn 的答案展示了如何使用排序网络对 23 个元素进行排序,这被认为是对小型常量大小数组进行排序的最快方法之一。我将研究另一种方法,即在不使用排序算法的情况下计算中位数。事实上,我根本不会排列输入数组。

由于我们讨论的是小数组,因此我们需要以最简单的方式实现一些O(N^2)算法。理想情况下,它应该根本没有分支,或者只有可良好预测的分支。此外,该算法的简单结构可以让我们对其进行矢量化,从而进一步提高其性能。

算法

我决定采用计数方法,此处使用该方法来加速小型线性搜索。首先,假设所有元素都不同。选择数组的任意元素:元素数量小于定义其在排序数组中的位置。我们可以迭代所有元素,并为每个元素计算小于它的元素数量。如果排序后的索引具有期望的值,我们可以停止算法。

Unfortunately, there may be equal elements in general case. We'll have to make our algorithm significantly slower and more complex to handle them. Instead of calculating the unique sorted index of an element, we can calculate interval of possible sorted indices for it. For any element, it is enough to count number of elements less than it (L) and number of elements equal to it (E), then sorted index fits range [L, L+R). If this interval contains desired sorted index (i.e. N/2), then we can stop the algorithm and return the considered element.

for (size_t i = 0; i < n; i++) {
    auto x = arr[i];
    //count number of "less" and "equal" elements
    int cntLess = 0, cntEq = 0;
    for (size_t j = 0; j < n; j++) {
        cntLess += arr[j] < x;
        cntEq += arr[j] == x;
    }
    //fast range checking from here: /sf/answers/1196687411/
    if ((unsigned int)(idx - cntLess) < cntEq)
        return x;
}
Run Code Online (Sandbox Code Playgroud)

Vectorization

The constructed algorithm has only one branch, which is rather predictable: it fails in all cases, except for the only case when we stop the algorithm. The algorithm is easy to vectorize using 8 elements per SSE register. Since we'll have to access some elements after the last one, I'll assume that the input array is padded with max=2^15-1 values up to 24 or 32 elements.

The first way is to vectorize inner loop by j. In this case inner loop would be executed only 3 times, but two 8-wide reductions must be done after it is finished. They eat more time than the inner loop itself. As a result, such a vectorization is not very efficient.

The second way is to vectorize outer loop by i. In this case we process 8 elements x = arr[i] at once. For each pack, we compare it with each element arr[j] in inner loop. After the inner loop we perform vectorized range check for the whole pack of 8 elements. If any of them succeeds, we determine exact number with simple scalar code (it eats little time anyway).

__m128i idxV = _mm_set1_epi16(idx);
for (size_t i = 0; i < n; i += 8) {
    //load pack of 8 elements
    auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
    //count number of less/equal elements for each element in the pack
    __m128i cntLess = _mm_setzero_si128();
    __m128i cntEq = _mm_setzero_si128();
    for (size_t j = 0; j < n; j++) {
        __m128i vAll = _mm_set1_epi16(arr[j]);
        cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
        cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
    }
    //perform range check for 8 elements at once
    __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
    if (int bm = _mm_movemask_epi8(mask)) {
        //range check succeeds for one of the elements, find and return it 
        for (int t = 0; t < 8; t++)
            if (bm & (1 << (2*t)))
                return arr[i + t];
    }
}
Run Code Online (Sandbox Code Playgroud)

Here we see _mm_set1_epi16 intrinsic in the innermost loop. GCC seems to have some performance issues with it. Anyway, it is eating time on each innermost iteration, which can be reduced if we process 8 elements at once in the innermost loop too. In such case we can do one vectorized load and 14 unpack instructions to obtain vAll for eight elements. Also, we'll have to write compare-and-count code for eight elements in loop body, so it acts as 8x unrolling too. The resulting code is the fastest one, a link to it can be found below.

Comparison

I have benchmarked various solutions on Ivy Bridge 3.4 Ghz processor. Below you can see total computation time for 2^23 ~= 8M calls in seconds (the first number). Second number is checksum of results.

Results on MSVC 2013 x64 (/O2):

memcpy only: 0.020
std::nth_element: 2.110 (1186136064)
network sort: 0.630 (1186136064)              //solution by @Morwenn (I had to change swap_if)
trivial count: 2.266 (1186136064)             //scalar algorithm (presented above)
vectorized count: 0.692 (1186136064)          //vectorization by j
vectorized count (T): 0.602 (1186136064)      //vectorization by i (presented above)
vectorized count (both): 0.450 (1186136064)   //vectorization by i and j
Run Code Online (Sandbox Code Playgroud)

Results on MinGW GCC 4.8.3 x64 (-O3 -msse4):

memcpy only: 0.016
std::nth_element: 1.981 (1095237632)
network sort: 0.531 (1095237632)              //original swap_if used
trivial count: 1.482 (1095237632)
vectorized count: 0.655 (1095237632)
vectorized count (T): 2.668 (1095237632)      //GCC generates some crap
vectorized count (both): 0.374 (1095237632)
Run Code Online (Sandbox Code Playgroud)

As you see, the proposed vectorized algorithm for 23 16-bit elements is a bit faster than sorting-based approach (BTW, on an older CPU I see only 5% time difference). If you can guarantee that all elements are different, you can simplify the algorithm, making it even faster.

所有算法的完整代码可以在这里找到,包括所有测试代码。

  • @stgatilov 完整源代码的结果: clang++ -msse4 -std=c++11 -O3 sorting.cpp -o sorting clang version 3.5.0 Linux savitri 3.16.7-24-desktop #1 SMP PREEMPT CPU:Intel Core i5仅 -2410M CPU @ 2.30GHz memcpy:0.061 std::nth_element:2.575 (974340096) 冒泡排序:5.202 (974340096) 选择排序:2.859 (974340096) 网络排序:0.777 (974340096) 平凡计数:1.440 (9) 74340096) 向量化计数: 0.723 (974340096) 向量化计数 (n&lt;=24): 0.705 (974340096) 向量化计数 (T): 0.707 (974340096) 向量化计数(两者): 0.440 (974340096) 所以两者向量化方法是最快的。 (2认同)

sp2*_*nny 2

我发现这个问题很有趣,所以我尝试了所有我能想到的算法。
结果如下:

testing 100000 repetitions
variant 0, no-op (for overhead measure)
5 ms
variant 1, vector + nth_element
205 ms
variant 2, multiset + advance
745 ms
variant 2b, set (not fully conformant)
787 ms
variant 3, list + lower_bound
589 ms
variant 3b, list + block-allocator
269 ms
variant 4, avl-tree + insert_sorted
645 ms
variant 4b, avl-tree + prune
682 ms
variant 5, histogram
1429 ms
Run Code Online (Sandbox Code Playgroud)

我想我们可以得出结论,您已经使用了最快的算法。男孩,我错了。但是,如果您可以接受近似答案,可能有更快的方法,例如中位数的中值
如果您有兴趣,来源在这里


归档时间:

查看次数:

1106 次

最近记录:

9 年,6 月 前