最快的代码C/C++,用于选择27个浮点值集合的中位数

chm*_*ike 39 c c++ algorithm optimization

这是众所周知的选择算法.见http://en.wikipedia.org/wiki/Selection_algorithm.

我需要它来找到一组3x3x3体素值的中值.由于体积由十亿个体素组成,算法是递归的,因此最好快一点.通常可以预期值相对接近.

到目前为止,我尝试过的最快的已知算法使用了快速排序分区功能.我想知道是否有更快的.

我已经"发明"了使用两个堆的速度提高了20%,但预计使用散列会更快.在实现这个之前,我想知道是否已经存在闪电战快速解决方案.

我使用浮点数的事实应该无关紧要,因为它们在反转符号位后可以被认为是无符号整数.订单将被保留.

编辑:基准和源代码按照Davy Landman的建议转移到单独的答案中.请参阅下面的chmike答案.

编辑:迄今为止最有效的算法被Boojum引用作为Fast Median和双边过滤论文的链接,现在这个问题的答案就是答案.这种方法的第一个聪明的想法是使用基数排序,第二个是组合共享大量像素的相邻像素的中值搜索.

new*_*cct 29

选择算法是线性时间(O(n)).复杂性方面你不能比线性时间更好,因为读取所有数据需要线性时间.所以你不可能做出更快速复杂的事情.也许你在某些输入上有更快的因素?我怀疑它会带来多大的不同.

C++已经包含线性时间选择算法.为什么不用它呢?

std::vector<YourType>::iterator first = yourContainer.begin();
std::vector<YourType>::iterator last = yourContainer.end();
std::vector<YourType>::iterator middle = first + (last - first) / 2;
std::nth_element(first, middle, last); // can specify comparator as optional 4th arg
YourType median = *middle;
Run Code Online (Sandbox Code Playgroud)

编辑:从技术上讲,这只是一个奇数长度的容器的中位数.对于偶数长度之一,它将获得"上限"中位数.如果你想为中位数,甚至长度的传统定义,你可能需要在两次运行,每进行一次两个"中段"的first + (last - first) / 2first + (last - first) / 2 - 1,然后取它们的平均值或东西.


chm*_*ike 21

编辑:我要道歉.下面的代码是错误的.我有固定的代码,但需要找到一个icc编译器来重做测量.

到目前为止所考虑的算法的基准结果

有关协议和算法的简短描述,请参见下文.第一个值是200个不同序列的平均时间(秒),第二个值是stdDev.

HeapSort     : 2.287 0.2097
QuickSort    : 2.297 0.2713
QuickMedian1 : 0.967 0.3487
HeapMedian1  : 0.858 0.0908
NthElement   : 0.616 0.1866
QuickMedian2 : 1.178 0.4067
HeapMedian2  : 0.597 0.1050
HeapMedian3  : 0.015 0.0049 <-- best
Run Code Online (Sandbox Code Playgroud)

协议:使用从rand()获得的随机位生成27个随机浮点数.连续应用每个算法500万次(包括先前的数组复制),并计算200个随机序列的平均值和stdDev.用icc -S -O3编译的C++代码,运行在带有8GB DDR3的Intel E8400上.

算法:

HeapSort:使用堆排序和选择中间值的完整序列.使用下标访问的朴素实现.

QuickSort:使用快速排序和选择中间值完整到位的序列.使用下标访问的朴素实现.

QuickMedian1:快速选择交换算法.使用下标访问的朴素实现.

HeapMedian1:使用先前交换的平衡堆方法.使用下标访问的朴素实现.

NthElement:使用nth_element STL算法.使用memcpy(vct.data(),rndVal,...)将数据复制到向量中;

QuickMedian2:使用带指针的快速选择算法并复制到两个缓冲区中以避免交换.基于MSalters的提议.

HeapMedian2:我发明的算法的变体,使用带有共享头的双堆.左堆具有最大值作为头,右边具有最小值作为头.初始化为第一个值作为公共头和第一个中值猜测.如果小于head,则将后续值添加到左堆,否则添加到右堆,直到其中一个堆已满.它包含14个值时已满.然后只考虑完整堆.如果它是正确的堆,对于大于头的所有值,弹出头和插入值.忽略所有其他值.如果是左堆,对于小于头的所有值,弹出头并将其插入堆中.忽略所有其他值.当所有值都已进行时,公共头是中值.它使用整数索引到数组.使用指针(64位)的版本似乎慢了近两倍(~1s).

HeapMedian3:与HeapMedian2相同的算法,但已经过优化.它使用unsigned char索引,避免了值交换和其他各种小事情.平均值和stdDev值是在1000个随机序列上计算的.对于nth_element,我使用相同的1000个随机序列测量0.508s和0.15%的stdDev.因此,HeapMedian3比nth_element stl函数快33倍.根据heapSort返回的中值检查每个返回的中值,它们都匹配.我怀疑使用哈希的方法可能会明显更快.

编辑1:该算法可以进一步优化.根据比较结果在左侧或右侧堆中调度元素的第一个阶段不需要堆.简单地将元素附加到两个无序序列就足够了.只要一个序列已满,第一阶段就会停止,这意味着它包含14个元素(包括中值).第二阶段从堆积整个序列开始,然后按照HeapMedian3算法中的描述继续.我会尽快提供新的代码和基准.

编辑2:我实现了优化算法并对其进行了基准测试.但是与heapMedian3相比,没有明显的性能差异.它的平均值甚至略慢.显示的结果已得到确认.可能会有更大的集合.另请注意,我只选择第一个值作为初始中位数猜测.如所建议的那样,我们可以从"重叠"值集中搜索中值来获益.使用中值算法的中值将有助于选择更好的初始中值猜测.


HeapMedian3的源代码

// return the median value in a vector of 27 floats pointed to by a
float heapMedian3( float *a )
{
   float left[14], right[14], median, *p;
   unsigned char nLeft, nRight;

   // pick first value as median candidate
   p = a;
   median = *p++;
   nLeft = nRight = 1;

   for(;;)
   {
       // get next value
       float val = *p++;

       // if value is smaller than median, append to left heap
       if( val < median )
       {
           // move biggest value to the heap top
           unsigned char child = nLeft++, parent = (child - 1) / 2;
           while( parent && val > left[parent] )
           {
               left[child] = left[parent];
               child = parent;
               parent = (parent - 1) / 2;
           }
           left[child] = val;

           // if left heap is full
           if( nLeft == 14 )
           {
               // for each remaining value
               for( unsigned char nVal = 27 - (p - a); nVal; --nVal )
               {
                   // get next value
                   val = *p++;

                   // if value is to be inserted in the left heap
                   if( val < median )
                   {
                       child = left[2] > left[1] ? 2 : 1;
                       if( val >= left[child] )
                           median = val;
                       else
                       {
                           median = left[child];
                           parent = child;
                           child = parent*2 + 1;
                           while( child < 14 )
                           {
                               if( child < 13 && left[child+1] > left[child] )
                                   ++child;
                               if( val >= left[child] )
                                   break;
                               left[parent] = left[child];
                               parent = child;
                               child = parent*2 + 1;
                           }
                           left[parent] = val;
                       }
                   }
               }
               return median;
           }
       }

       // else append to right heap
       else
       {
           // move smallest value to the heap top
           unsigned char child = nRight++, parent = (child - 1) / 2;
           while( parent && val < right[parent] )
           {
               right[child] = right[parent];
               child = parent;
               parent = (parent - 1) / 2;
           }
           right[child] = val;

           // if right heap is full
           if( nRight == 14 )
           {
               // for each remaining value
               for( unsigned char nVal = 27 - (p - a); nVal; --nVal )
               {
                   // get next value
                   val = *p++;

                   // if value is to be inserted in the right heap
                   if( val > median )
                   {
                       child = right[2] < right[1] ? 2 : 1;
                       if( val <= right[child] )
                           median = val;
                       else
                       {
                           median = right[child];
                           parent = child;
                           child = parent*2 + 1;
                           while( child < 14 )
                           {
                               if( child < 13 && right[child+1] < right[child] )
                                   ++child;
                               if( val <= right[child] )
                                   break;
                               right[parent] = right[child];
                               parent = child;
                               child = parent*2 + 1;
                           }
                           right[parent] = val;
                       }
                   }
               }
               return median;
           }
       }
   }
} 
Run Code Online (Sandbox Code Playgroud)


Boo*_*jum 14

由于听起来您正在对大量体数据执行中值滤波,因此您可能需要查看SIGGRAPH 2006 快速中值和双边滤波文章.该文章涉及2D图像处理,但您可能是能够适应3D体积的算法.如果不出意外,它可能会给你一些关于如何退一步并从略微不同的角度看待问题的想法.


ste*_*han 13

这个问题不容易回答,原因很简单,一个算法相对于另一个算法的性能与算法本身的编译器/处理器/数据结构组合一样多,因为你肯定知道

因此,尝试其中几个的方法似乎已经足够了.是的,快速排序应该非常快.如果您还没有这样做,您可能想尝试insertionsort,它通常在小数据集上表现更好.这就是说,只需要解决一个能够快速完成工作的排序算法.选择"正确"的算法通常不会快10倍.

为了获得大幅度的加速,更好的方法是使用更多的结构.一些过去对我有用的大型问题的想法:

  • 你可以在创建体素时有效地预先计算并存储28而不是27个浮点数吗?

  • 一个近似解决方案是否足够好?如果是这样,只需看看9个值的中位数,因为"通常可以预期值相对接近".或者只要值相对接近,您就可以用平均值替换它.

  • 你真的需要所有数十亿体素的中位数吗?也许你有一个简单的测试,你是否需要中位数,然后只能计算相关的子集.

  • 如果没有其他帮助:查看编译器生成的asm代码.您可能能够编写速度更快的asm代码(例如,通过使用寄存器执行所有计算).

编辑:为了它的价值,我附上了下面评论中提到的(部分)插入代码(完全未经测试).如果numbers[]是一个大小数组N,并且您希望P在数组的开头排序最小的浮点数,请调用partial_insertionsort<N, P, float>(numbers);.因此,如果你打电话partial_insertionsort<27, 13, float>(numbers);,numbers[13]将包含中位数.要获得额外的速度,你也必须展开while循环.如上所述,为了获得非常快的速度,您必须使用您对数据的了解(例如,数据是否已经部分排序?您是否知道数据分布的属性?我想,您会得到漂移).

template <long i> class Tag{};

template<long i, long N, long P, typename T>
inline void partial_insertionsort_for(T a[], Tag<N>, Tag<i>)
{   long j = i <= P+1 ? i : P+1;  // partial sort
    T temp = a[i];
    a[i] = a[j];       // compiler should optimize this away where possible
    while(temp < a[j - 1] && j > 0)
    { a[j] = a[j - 1];
      j--;}
    a[j] = temp;
    partial_insertionsort_for<i+1,N,P,T>(a,Tag<N>(),Tag<i+1>());}

template<long i, long N, long P, typename T>
inline void partial_insertionsort_for(T a[], Tag<N>, Tag<N>){}

template <long N, long P, typename T>
inline void partial_insertionsort(T a[])
 {partial_insertionsort_for<0,N,P,T>(a, Tag<N>(), Tag<0>());}
Run Code Online (Sandbox Code Playgroud)


MSa*_*ers 6

在第一次尝试中使用的最可能的算法就是nth_element; 它几乎可以直接给你你想要的东西.只要问第14个元素.

在第二次尝试时,目标是利用固定的数据大小.你根本不需要为你的算法分配任何内存.因此,将您的体素值复制到预先分配的27个元素的数组中.选择一个轴,然后将其复制到53个元素阵列的中间.将剩余值复制到数据透视的任一侧.在这里你保留两个指针(float* left = base+25, *right=base+27).现在有三种可能性:左侧较大,右侧较大,或两者都有12个元素.最后一个案例是微不足道的; 你的支点是中位数.否则,在左侧或右侧调用nth_element.Nth的确切值取决于多少或多于枢轴的值.例如,如果除法是12/14,则需要比枢轴大的最小元素,因此Nth = 0,如果除法是14/12,则需要最大元素小于枢轴,因此Nth = 13.最糟糕的情况是26/0和0/26,当你的枢轴是一个极端,但这些只发生在所有情况的2/27.

第三个改进(或第一个改进,如果你必须使用C并且没有nth_element)完全替换nth_element.您仍然拥有53个元素数组,但这次您直接从体素值填充它(将临时副本保存为a float[27]).第一次迭代中的枢轴只是体素[0] [0] [0].对于后续迭代,您使用第二个预分配float[53](如果两个大小相同则更容易)并在两者之间复制浮点数.这里的基本迭代步骤仍然是:将枢轴复制到中间,将其余部分分类到左侧和右侧.在每个步骤结束时,您将知道中位数是小于还是大于当前枢轴,因此您可以丢弃大于或小于该枢轴的浮动.每次迭代,这消除了1到12个元素,平均剩余25%.

如果您仍然需要更高的速度,最后一次迭代是基于大多数体素显着重叠的观察结果.您为每个3x3x1切片预先计算中值.然后,当您需要3x3x3体素立方体的初始枢轴时,您将获取三者的中位数.您事先知道有9个体素较小,9个体素大于中位数(4 + 4 + 1).因此,在第一个枢转步骤之后,最坏的情况是9/17和17/9分裂.所以,你只需要在float [17]中找到第4或第13个元素,而不是在float [26]中找到第12或第14个元素.


背景:使用左右指针首先复制一个枢轴,然后将浮动[N]的其余部分复制到一个浮点[2N-1]的想法是你在所有元素周围填充一个浮点[N]子阵列小于左侧的枢轴(较低的索引)和较高的右侧(较高的索引).现在,如果你想要Mth元素,你可能会发现自己很幸运并且M-1元素比枢轴小,在这种情况下,pivot是你需要的元素.如果有多于(M-1)个元素小于枢轴,则第M个元素就在其中,因此您可以丢弃枢轴和任何大于枢轴的东西,并在所有较低值中丢弃第M个元素的seacrh.如果小于(M-1)个元素小于数据透视表,则表示您要查找高于数据透视表的值.所以,你将丢弃枢轴和任何小于它的东西.让小于枢轴的元素数量,即枢轴左侧的元素数量为L.在下一次迭代中,您需要(NL-1)浮点数大于枢轴的(ML-1)元素.

这种nth_element算法是相当有效的,因为大部分工作都花在两个小数组之间复制浮点数,两个小数组都在缓存中,并且因为你的状态大部分时间用3个指针表示(源指针,左目标指针) ,右目的地指针).

要显示基本代码:

float in[27], out[53];
float pivot = out[26] = in[0];     // pivot
float* left = out+25, right = out+27
for(int i = 1; i != 27; ++1)
if((in[i]<pivot)) *left-- = in[i] else *right++ = in[i];
// Post-condition: The range (left+1, right) is initialized.
// There are 25-(left-out) floats <pivot and (right-out)-27 floats >pivot
Run Code Online (Sandbox Code Playgroud)