使用SIMD/SSE进行水平运行差异和条件更新?

Kar*_*ner 2 c c++ sse simd vectorization

我想矢量化以下操作:

V[i+1] = max(V[i] - c, V[i+1]) for i=1 to n-1 (V[0] = 0)
Run Code Online (Sandbox Code Playgroud)

相应的天真伪代码是:

/* Naive serial form: each element becomes the larger of itself and its
   (already updated) left neighbour minus c. */
for (i=0; i < n-1; i++) {
  /* stop at i == n-2: V[i+1] is then the last element, V[n-1] */
  if (V[i]-c > V[i+1]) V[i+1] = V[i]-c;
}
Run Code Online (Sandbox Code Playgroud)

哪些SIMD说明有用?

Z b*_*son 5

这可以通过SIMD完成.该解决方案类似于SIMD前缀和的解决方案.

在单个SIMD寄存器内,所需的迭代次数为 O(log2(simd_width)).每次迭代都需要:一次移位、一次减法和一次取最大值.例如,对于SSE(4个32位元素),需要 log2(4) = 2 次迭代.您可以按如下方式对四个元素应用您的函数:

/*
 * In-register form of the recurrence V[j] = max(V[j-1] - c, V[j]) for the
 * four 32-bit lanes of x.  The result is
 *     out[j] = max over i <= j of (x[i] - (j - i) * c)
 * computed in log2(4) = 2 rounds of shuffle / subtract / max.
 *
 * x: four signed 32-bit values, lane 0 being the first element.
 * c: per-step penalty; may be negative.
 * Returns the four updated values.
 *
 * A lane that has no neighbour at the required distance is fed its own
 * value with a zero penalty, so it passes through the max unchanged.  A
 * zero-filling byte shift would instead compare such a lane against -c
 * (or -2c), which is wrong for negative inputs and for c < 0.
 *
 * The vector subtraction wraps on overflow; inputs are assumed small enough
 * that x[i] - 3*c stays within 32 bits.
 *
 * The target attribute enables SSE4.1 (_mm_max_epi32) for this function
 * even when the file is not built with -msse4.1 (GCC/Clang).
 */
__attribute__((target("sse4.1")))
__m128i foo_SSE(__m128i x, int c) {
    /* Penalties per lane; 0 where the lane is compared with itself. */
    const __m128i c1 = _mm_setr_epi32(0, c, c, c);
    const __m128i c2 = _mm_setr_epi32(0, 0, 2*c, 2*c);
    __m128i t;

    /* Round 1, distance 1: t = (x0, x0, x1, x2) - (0, c, c, c). */
    t = _mm_shuffle_epi32(x, 0x90);
    t = _mm_sub_epi32(t, c1);
    x = _mm_max_epi32(x, t);

    /* Round 2, distance 2: t = (x0, x1, x0, x1) - (0, 0, 2c, 2c). */
    t = _mm_shuffle_epi32(x, 0x44);
    t = _mm_sub_epi32(t, c2);
    x = _mm_max_epi32(x, t);
    return x;
}
Run Code Online (Sandbox Code Playgroud)

获得单个SIMD寄存器的结果后,需要将"进位"应用于下一个寄存器.例如,假设您有一个包含八个元素的数组 a.您可以像这样把它加载到SSE寄存器 x1 和 x2 中:

/* Load a[0..3] into x1 and a[4..7] into x2 (unaligned loads). */
__m128i x1 = _mm_loadu_si128((__m128i*)&a[0]);
__m128i x2 = _mm_loadu_si128((__m128i*)&a[4]);
Run Code Online (Sandbox Code Playgroud)

然后,要把您的函数应用到全部八个元素上,可以这样做:

__m128i t, s;
/* Penalty for carrying the last element of x1 into lanes 0..3 of x2:
   the distances are 1, 2, 3 and 4 steps. */
s = _mm_setr_epi32(c, 2*c, 3*c, 4*c);

/* Process each register on its own first. */
x1 = foo_SSE(x1,c);
x2 = foo_SSE(x2,c);
/* 0xff broadcasts lane 3 (the last element) of x1 to all four lanes. */
t = _mm_shuffle_epi32(x1, 0xff);
t = _mm_sub_epi32(t,s);
/* Apply the carry: each lane of x2 keeps the larger of its own value and
   the decayed last element of x1. */
x2 = _mm_max_epi32(x2,t);
Run Code Online (Sandbox Code Playgroud)

需要注意的是,c1、c2 和 s 在循环中都是常量,所以它们只需要计算一次.

一般地,您可以像下面这样用SSE把您的函数应用到unsigned int数组 a 上(要求 n 是4的倍数):

/*
 * Applies a[j] = max(a[j-1] - c, a[j]) in place over a[0..n-1], four
 * elements at a time.
 *
 * a: array of n ints, updated in place (no alignment requirement).
 * n: element count; need not be a multiple of 4, values <= 0 do nothing.
 * c: per-step penalty.
 *
 * Each block of four is first processed on its own by foo_SSE.  The last
 * element of the previous block, decayed by 1..4 steps, is then merged in
 * with a max.  The carry logic here places no restriction on the sign of
 * the data or of c; any such restriction comes from foo_SSE.
 *
 * The target attribute enables SSE4.1 (_mm_max_epi32) for this function
 * even when the file is not built with -msse4.1 (GCC/Clang).
 */
__attribute__((target("sse4.1")))
void fill_SSE(int *a, int n, int c) {
    /* Carry into the first block: the most negative 32-bit value, so it can
       never win the max (a zero here would clamp negative inputs). */
    __m128i offset = _mm_set1_epi32(-2147483647 - 1);
    const __m128i s = _mm_setr_epi32(c, 2*c, 3*c, 4*c);
    int i = 0;
    for(; n - i >= 4; i += 4) {
        __m128i x = _mm_loadu_si128((__m128i*)&a[i]);
        __m128i out = foo_SSE(x, c);
        out = _mm_max_epi32(out,offset);
        _mm_storeu_si128((__m128i*)&a[i], out);
        /* Next carry: last element of this block minus (c, 2c, 3c, 4c). */
        offset = _mm_shuffle_epi32(out, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
    /* Scalar tail for the final n % 4 elements. */
    for(; i < n; i++) {
        if(i > 0 && a[i-1]-c > a[i]) a[i] = a[i-1]-c;
    }
}
Run Code Online (Sandbox Code Playgroud)

我对这段SSE代码做了性能测试.它比串行版快2.5倍.

除了只需 log2(simd_width) 次迭代之外,这种方法的另一个主要优点是可以打破依赖链,使多个SIMD操作可以同时进行(使用多个执行端口),而不必等待前一个结果.串行代码则受延迟限制.

当前代码适用于无符号整数,但您可以将其概括为有符号整数和浮点数.

这是我用来测试它的一般代码.在实现SSE版本之前,我创建了一堆抽象的SIMD函数来模拟SIMD硬件.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <x86intrin.h>
#include <omp.h>

/*
 * In-register form of the recurrence V[j] = max(V[j-1] - c, V[j]) for the
 * four 32-bit lanes of x.  The result is
 *     out[j] = max over i <= j of (x[i] - (j - i) * c)
 * computed in log2(4) = 2 rounds of shuffle / subtract / max.
 *
 * x: four signed 32-bit values, lane 0 being the first element.
 * c: per-step penalty; may be negative.
 * Returns the four updated values.
 *
 * A lane that has no neighbour at the required distance is fed its own
 * value with a zero penalty, so it passes through the max unchanged.  A
 * zero-filling byte shift would instead compare such a lane against -c
 * (or -2c), which is wrong for negative inputs and for c < 0.
 *
 * The vector subtraction wraps on overflow; inputs are assumed small enough
 * that x[i] - 3*c stays within 32 bits.
 *
 * The target attribute enables SSE4.1 (_mm_max_epi32) for this function
 * even when the file is not built with -msse4.1 (GCC/Clang).
 */
__attribute__((target("sse4.1")))
__m128i foo_SSE(__m128i x, int c) {
    /* Penalties per lane; 0 where the lane is compared with itself. */
    const __m128i c1 = _mm_setr_epi32(0, c, c, c);
    const __m128i c2 = _mm_setr_epi32(0, 0, 2*c, 2*c);
    __m128i t;

    /* Round 1, distance 1: t = (x0, x0, x1, x2) - (0, c, c, c). */
    t = _mm_shuffle_epi32(x, 0x90);
    t = _mm_sub_epi32(t, c1);
    x = _mm_max_epi32(x, t);

    /* Round 2, distance 2: t = (x0, x1, x0, x1) - (0, 0, 2c, 2c). */
    t = _mm_shuffle_epi32(x, 0x44);
    t = _mm_sub_epi32(t, c2);
    x = _mm_max_epi32(x, t);
    return x;
}

/*
 * Serial reference implementation: walks a[1..n-1] left to right and raises
 * each element to its (already updated) left neighbour minus c when that is
 * larger.  a is modified in place; nothing happens for n < 2.
 */
void foo(int *a, int n, int c) {
    int k = 1;
    while (k < n) {
        int decayed = a[k - 1] - c;
        if (decayed > a[k]) {
            a[k] = decayed;
        }
        k++;
    }
}

/* Broadcast: stores k into each of the first n elements of a. */
void broad(int *a, int n, int k) {
    int *dst = a;
    for (int left = n; left > 0; left--) {
        *dst++ = k;
    }
}

/*
 * Writes into b the n-element array a shifted towards higher indices by m:
 *   b[i] = a[i]     for i <  m  (the vacated low lanes keep a's own value)
 *   b[i] = a[i - m] for i >= m
 *
 * a: source, n elements.  b: destination, n elements, must not alias a.
 * m: shift distance, expected to be >= 0.  If m >= n, b is simply a copy
 *    of a.  The low-lane loop is bounded by n as well as m so that it never
 *    reads or writes past the end of either array.
 */
void shiftr(int *a, int *b, int n, int m) {
    int i;
    for(i=0; i<m && i<n; i++) b[i] = a[i];
    for(; i<n; i++) b[i] = a[i-m];
}

/*
void shiftr(int *a, int *b, int n, int m) {
    int i;
    for(i=0; i<m; i++) b[i] = 0;
    for(; i<n; i++) b[i] = a[i-m];
}
*/

/* Subtracts the scalar c from each of the first n elements of a, in place. */
void sub(int *a, int n, int c) {
    int idx = 0;
    while (idx < n) {
        a[idx] = a[idx] - c;
        ++idx;
    }
}


/*
 * Element-wise maximum: for each of the first n positions, a takes b's
 * value when b's is strictly larger.  b is only read.
 */
void max(int *a, int *b, int n) {
    int *dst = a;
    const int *src = b;
    for (int left = n; left > 0; --left, ++dst, ++src) {
        if (*dst < *src) {
            *dst = *src;
        }
    }
}

/*
 * Subtracts a linearly growing penalty from a, in place: element i is
 * reduced by (i + 1) * c, i.e. c, 2c, 3c, ... for the first n elements.
 */
void step(int *a, int n, int c) {
    int idx = 0;
    while (idx < n) {
        int penalty = (idx + 1) * c;
        a[idx] = a[idx] - penalty;
        ++idx;
    }
}

/*
 * Scalar emulation of the log-step SIMD scheme over the whole array: for
 * m = 1, 2, 4, ... < n, every element at index >= m is raised to
 * a[i - m] - m*c when that is larger.  Afterwards a holds the same result
 * as the serial recurrence a[j] = max(a[j-1] - c, a[j]).
 *
 * a: n ints, updated in place.  c: per-step penalty, may be negative.
 */
void foo2(int *a, int n, int c) {
    if(n <= 0) return;   /* a variable-length array needs a positive size */
    int b[n];
    for(int m=1; m<n; m*=2) {
        shiftr(a,b,n,m);
        sub(b, n, m*c);
        /* Only lanes >= m have a neighbour m positions to the left.  The
           lower lanes of b hold a[i] - m*c, which would wrongly raise a[i]
           when c < 0, so they are left out of the max. */
        max(a+m, b+m, n-m);
    }
}

/*
 * Scalar emulation of the blocked SIMD algorithm with a register width of
 * w elements: each chunk of w is scanned on its own in log2(w) rounds, then
 * the decayed last element of the previous chunk is merged in.
 *
 * a: n ints, updated in place.
 * n: element count; a tail of n % w elements is finished serially.
 * w: emulated vector width, must be > 0 (otherwise nothing is done).
 * c: per-step penalty, may be negative.
 */
void fill(int *a, int n, int w, int c) {
    if(w <= 0) return;   /* avoids n/w by zero and a non-positive VLA size */
    int b[w], offset[w];
    int i;
    for(i=0; i<n/w; i++) {
        int *chunk = &a[w*i];
        for(int m=1; m<w; m*=2) {
            shiftr(chunk,b,w,m);
            sub(b, w, m*c);
            /* Lanes below m have no neighbour at distance m inside the
               chunk; skip them so they are not raised when c < 0. */
            max(chunk+m, b+m, w-m);
        }
        /* The first chunk has no predecessor, hence no carry to merge
           (offset is not yet initialised at that point). */
        if(i > 0) max(chunk,offset,w);
        /* Carry for the next chunk: last element minus c, 2c, ..., w*c. */
        broad(offset,w,chunk[w-1]);
        step(offset, w, c);
    }
    /* Serial tail for the last n % w elements. */
    for(int j=w*i; j<n; j++) {
        if(j > 0 && a[j-1]-c > a[j]) a[j] = a[j-1]-c;
    }
}


/*
 * Applies a[j] = max(a[j-1] - c, a[j]) in place over a[0..n-1], four
 * elements at a time.
 *
 * a: array of n ints, updated in place (no alignment requirement).
 * n: element count; need not be a multiple of 4, values <= 0 do nothing.
 * c: per-step penalty.
 *
 * Each block of four is first processed on its own by foo_SSE.  The last
 * element of the previous block, decayed by 1..4 steps, is then merged in
 * with a max.  The carry logic here places no restriction on the sign of
 * the data or of c; any such restriction comes from foo_SSE.
 *
 * The target attribute enables SSE4.1 (_mm_max_epi32) for this function
 * even when the file is not built with -msse4.1 (GCC/Clang).
 */
__attribute__((target("sse4.1")))
void fill_SSE(int *a, int n, int c) {
    /* Carry into the first block: the most negative 32-bit value, so it can
       never win the max (a zero here would clamp negative inputs). */
    __m128i offset = _mm_set1_epi32(-2147483647 - 1);
    const __m128i s = _mm_setr_epi32(c, 2*c, 3*c, 4*c);
    int i = 0;
    for(; n - i >= 4; i += 4) {
        __m128i x = _mm_loadu_si128((__m128i*)&a[i]);
        __m128i out = foo_SSE(x, c);
        out = _mm_max_epi32(out,offset);
        _mm_storeu_si128((__m128i*)&a[i], out);
        /* Next carry: last element of this block minus (c, 2c, 3c, 4c). */
        offset = _mm_shuffle_epi32(out, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
    /* Scalar tail for the final n % 4 elements. */
    for(; i < n; i++) {
        if(i > 0 && a[i-1]-c > a[i]) a[i] = a[i-1]-c;
    }
}

/*
 * Applies a[j] = max(a[j-1] - c, a[j]) in place over a[0..n-1] with the
 * in-register scan written out inline and all constants hoisted out of the
 * loop.
 *
 * a: array of n signed ints, updated in place (no alignment requirement).
 * n: element count; need not be a multiple of 4, values <= 0 do nothing.
 * c: per-step penalty; may be negative.
 *
 * Per block of four: two shuffle/sub/max rounds scan the block on its own,
 * then the decayed last element of the previous block is merged in.  Lanes
 * without an in-block neighbour are compared against themselves with a zero
 * penalty, so negative data and negative c are handled; a zero-filling byte
 * shift would compare them against -c instead.
 *
 * Vector subtraction wraps; values are assumed small enough that subtracting
 * up to 4*c does not overflow 32 bits.
 *
 * The target attribute enables SSE4.1 (_mm_max_epi32) for this function
 * even when the file is not built with -msse4.1 (GCC/Clang).
 */
__attribute__((target("sse4.1")))
void fill_SSEv2(int *a, int n, int c) {
    /* Carry into the first block: most negative 32-bit value, never wins. */
    __m128i offset = _mm_set1_epi32(-2147483647 - 1);
    const __m128i s  = _mm_setr_epi32(1*c, 2*c, 3*c, 4*c);
    const __m128i c1 = _mm_setr_epi32(0, c, c, c);
    const __m128i c2 = _mm_setr_epi32(0, 0, 2*c, 2*c);
    int i = 0;
    for(; n - i >= 4; i += 4) {
        __m128i x1 = _mm_loadu_si128((__m128i*)&a[i]);
        __m128i t1;

        /* distance 1: (x0, x0, x1, x2) - (0, c, c, c) */
        t1 = _mm_shuffle_epi32(x1, 0x90);
        t1 = _mm_sub_epi32 (t1, c1);
        x1 = _mm_max_epi32 (x1, t1);

        /* distance 2: (x0, x1, x0, x1) - (0, 0, 2c, 2c) */
        t1 = _mm_shuffle_epi32(x1, 0x44);
        t1 = _mm_sub_epi32 (t1, c2);
        x1 = _mm_max_epi32 (x1, t1);

        x1 = _mm_max_epi32(x1,offset);
        _mm_storeu_si128((__m128i*)&a[i], x1);
        /* Next carry: last element of this block minus (c, 2c, 3c, 4c). */
        offset = _mm_shuffle_epi32(x1, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
    /* Scalar tail for the final n % 4 elements. */
    for(; i < n; i++) {
        if(i > 0 && a[i-1]-c > a[i]) a[i] = a[i-1]-c;
    }
}

/*
 * Same operation as the 4-wide version, a[j] = max(a[j-1] - c, a[j]) in
 * place over a[0..n-1], unrolled to two registers (eight elements) per
 * iteration.  The two in-register scans are independent of each other; only
 * the carry merge at the end of the iteration is sequential.
 *
 * a: array of n signed ints, updated in place (no alignment requirement).
 * n: element count; need not be a multiple of 8, values <= 0 do nothing.
 * c: per-step penalty; may be negative.
 *
 * Lanes without an in-block neighbour are compared against themselves with
 * a zero penalty, so negative data and negative c are handled.  Vector
 * subtraction wraps; values are assumed small enough not to overflow.
 *
 * The target attribute enables SSE4.1 (_mm_max_epi32) for this function
 * even when the file is not built with -msse4.1 (GCC/Clang).
 */
__attribute__((target("sse4.1")))
void fill_SSEv3(int *a, int n, int c) {
    /* Carry into the first block: most negative 32-bit value, never wins. */
    __m128i offset = _mm_set1_epi32(-2147483647 - 1);
    const __m128i s  = _mm_setr_epi32(1*c, 2*c, 3*c, 4*c);
    const __m128i c1 = _mm_setr_epi32(0, c, c, c);
    const __m128i c2 = _mm_setr_epi32(0, 0, 2*c, 2*c);
    int i = 0;
    for(; n - i >= 8; i += 8) {
        __m128i x1 = _mm_loadu_si128((__m128i*)&a[i]);
        __m128i x2 = _mm_loadu_si128((__m128i*)&a[i+4]);
        __m128i t1, t2;

        /* Scan of the first register: distance 1, then distance 2. */
        t1 = _mm_shuffle_epi32(x1, 0x90);
        t1 = _mm_sub_epi32 (t1, c1);
        x1 = _mm_max_epi32 (x1, t1);

        t1 = _mm_shuffle_epi32(x1, 0x44);
        t1 = _mm_sub_epi32 (t1, c2);
        x1 = _mm_max_epi32 (x1, t1);

        /* Scan of the second register, independent of the first. */
        t2 = _mm_shuffle_epi32(x2, 0x90);
        t2 = _mm_sub_epi32 (t2, c1);
        x2 = _mm_max_epi32 (x2, t2);

        t2 = _mm_shuffle_epi32(x2, 0x44);
        t2 = _mm_sub_epi32 (t2, c2);
        x2 = _mm_max_epi32 (x2, t2);

        /* Merge the carry into x1, then x1's last element into x2. */
        x1 = _mm_max_epi32(x1,offset);
        _mm_storeu_si128((__m128i*)&a[i], x1);
        offset = _mm_shuffle_epi32(x1, 0xff);
        offset = _mm_sub_epi32(offset,s);

        x2 = _mm_max_epi32(x2,offset);
        _mm_storeu_si128((__m128i*)&a[i+4], x2);
        offset = _mm_shuffle_epi32(x2, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
    /* Scalar tail for the final n % 8 elements. */
    for(; i < n; i++) {
        if(i > 0 && a[i-1]-c > a[i]) a[i] = a[i-1]-c;
    }
}

/*
 * Test and benchmark driver.
 *  1. Runs foo and foo2 on copies of an 8-element array and prints the
 *     results so they can be compared by eye.
 *  2. Repeats the computation with two SSE registers and a manual carry.
 *  3. Checks fill_SSEv3 against foo on 8000 random values (prints the
 *     memcmp result, 0 meaning the outputs are identical).
 *  4. Times r repetitions each of fill_SSEv2, fill_SSEv3 and foo.
 */
int main(void) {
    int n = 8, a[n], a1[n], a2[n];
    for(int i=0; i<n; i++) a[i] = i;

    /*
    a[0] = 1, a[1] = 0;
    a[2] = 2, a[3] = 0;
    a[4] = 3, a[5] = 13;
    a[6] = 4, a[7] = 0;
    */


    /* Overwrites all eight values set by the loop above. */
    a[0] = 5, a[1] = 6;
    a[2] = 7, a[3] = 8;
    a[4] = 1, a[5] = 2;
    a[6] = 3, a[7] = 4;

    /* Print the input, then give foo and foo2 their own copies of it. */
    for(int i=0; i<n; i++) printf("%2d ", a[i]); puts("");
    for(int i=0; i<n; i++) a1[i] = a[i], a2[i] = a[i];

    int c = 1;
    foo(a1,n,c);
    foo2(a2,n,c);
    for(int i=0; i<n; i++) printf("%2d ", a1[i]); puts("");
    for(int i=0; i<n; i++) printf("%2d ", a2[i]); puts("");


    /* Two-register SSE version on the unmodified input a: process each
       half, then carry the last element of x1 into x2. */
    __m128i x1 = _mm_loadu_si128((__m128i*)&a[0]);
    __m128i x2 = _mm_loadu_si128((__m128i*)&a[4]);
    __m128i t, s;
    s = _mm_setr_epi32(c, 2*c, 3*c, 4*c);

    x1 = foo_SSE(x1,c);
    x2 = foo_SSE(x2,c);
    /* 0xff broadcasts lane 3 of x1 to all lanes. */
    t = _mm_shuffle_epi32(x1, 0xff);
    t = _mm_sub_epi32(t,s);
    x2 = _mm_max_epi32(x2,t);

    int a3[8];
    _mm_storeu_si128((__m128i*)&a3[0], x1);
    _mm_storeu_si128((__m128i*)&a3[4], x2);
    for(int i=0; i<8; i++) printf("%2d ", a3[i]); puts("");

    /* Larger check: n becomes 8000; f1 and f2 start as identical arrays of
       random values in [0, 999]. */
    int w = 8;
    n = w*1000;
    int f1[n], f2[n];
    for(int i=0; i<n; i++) f1[i] = rand()%1000;

    for(int i=0; i<n; i++) f2[i] = f1[i];
    //for(int i=0; i<n; i++) printf("%2d ", f1[i]); puts("");
    foo(f1, n, c);
    //fill(f2, n, 8, c);
    fill_SSEv3(f2, n, c);
    /* Prints 0 when the serial and SSE results are byte-for-byte equal. */
    printf("%d\n", memcmp(f1,f2,sizeof(int)*n));
    /* Loop body is commented out; enabling it prints mismatching indices. */
    for(int i=0; i<n; i++) {
        //    if(f1[i] != f2[i]) printf("%d\n", i);
    }
    //for(int i=0; i<n; i++) printf("%2d ", f1[i]); puts("");
    //for(int i=0; i<n; i++) printf("%2d ", f2[i]); puts("");

    /* Timing: dtime starts as minus the start time, so adding the end time
       leaves the elapsed wall-clock time.  Each repetition re-processes the
       same array in place. */
    int r = 200000;
    double dtime;
    dtime = -omp_get_wtime();
    for(int i=0; i<r; i++) fill_SSEv2(f2, n, c);
    //for(int i=0; i<r; i++) foo(f1, n, c);
    dtime += omp_get_wtime();
    printf("time %f\n", dtime);

    dtime = -omp_get_wtime();
    for(int i=0; i<r; i++) fill_SSEv3(f2, n, c);
    //for(int i=0; i<r; i++) foo(f1, n, c);
    dtime += omp_get_wtime();
    printf("time %f\n", dtime);

    dtime = -omp_get_wtime();
    for(int i=0; i<r; i++) foo(f1, n, c);
    //for(int i=0; i<r; i++) fill_SSEv2(f2, n, c);
    dtime += omp_get_wtime();
    printf("time %f\n", dtime);
}
Run Code Online (Sandbox Code Playgroud)

根据Paul R的评论,我修复了我的函数,使其可以用于有符号整数.但是,它要求c>=0.我相信它也可以修改为支持c<0.

/*
 * Applies a[j] = max(a[j-1] - c, a[j]) in place over a[0..n-1] for signed
 * 32-bit values.
 *
 * a: array of n signed ints, updated in place (no alignment requirement).
 * n: element count; need not be a multiple of 4, values <= 0 do nothing.
 * c: per-step penalty; may be negative.
 *
 * Per block of four: two shuffle/sub/max rounds scan the block on its own,
 * then the decayed last element of the previous block is merged in.  Lanes
 * without an in-block neighbour are compared against themselves.  Giving
 * those lanes a zero penalty (rather than c or 2c) keeps them unchanged for
 * any sign of c.
 *
 * Vector subtraction wraps; values are assumed small enough that subtracting
 * up to 4*c does not overflow 32 bits.
 *
 * The target attribute enables SSE4.1 (_mm_max_epi32) for this function
 * even when the file is not built with -msse4.1 (GCC/Clang).
 */
__attribute__((target("sse4.1")))
void fill_SSEv2(int *a, int n, int c) {
    /* Carry into the first block: the most negative 32-bit value, so it can
       never win the max.  (0xf0000000 is larger than that minimum and would
       clamp very negative inputs.) */
    __m128i offset = _mm_set1_epi32(-2147483647 - 1);
    const __m128i s  = _mm_setr_epi32(1*c, 2*c, 3*c, 4*c);
    const __m128i c1 = _mm_setr_epi32(0, c, c, c);
    const __m128i c2 = _mm_setr_epi32(0, 0, 2*c, 2*c);
    int i = 0;
    for(; n - i >= 4; i += 4) {
        __m128i x1 = _mm_loadu_si128((__m128i*)&a[i]);
        __m128i t1;

        /* distance 1: (x0, x0, x1, x2) - (0, c, c, c) */
        t1 = _mm_shuffle_epi32(x1, 0x90);
        t1 = _mm_sub_epi32 (t1, c1);
        x1 = _mm_max_epi32 (x1, t1);

        /* distance 2: (x0, x1, x0, x1) - (0, 0, 2c, 2c) */
        t1 = _mm_shuffle_epi32(x1, 0x44);
        t1 = _mm_sub_epi32 (t1, c2);
        x1 = _mm_max_epi32 (x1, t1);

        x1 = _mm_max_epi32(x1,offset);
        _mm_storeu_si128((__m128i*)&a[i], x1);
        /* Next carry: last element of this block minus (c, 2c, 3c, 4c). */
        offset = _mm_shuffle_epi32(x1, 0xff);
        offset = _mm_sub_epi32(offset,s);
    }
    /* Scalar tail for the final n % 4 elements. */
    for(; i < n; i++) {
        if(i > 0 && a[i-1]-c > a[i]) a[i] = a[i-1]-c;
    }
}
Run Code Online (Sandbox Code Playgroud)

这种方法现在应该很容易扩展到浮点数.

  • @PeterCordes,`c1`和`c2`在循环中是常量,因此不必在每次迭代`fill_SSE`调用`foo_SSE`时都重新计算它们.我把它们放在`foo_SSE`里,只是因为那是让代码先跑通的第一版.如果我要对代码做基准测试,我会把它们提到循环外面(或者把它们声明为const,或者用其他任何办法让编译器只计算一次).此外,如果我要对代码做基准测试,我也会在`fill_SSE`中展开循环. (2认同)
  • 很酷.我得记住,对于像这样的串行依赖,有这样一种有用的技术.非常好的答案.在我上一条评论里,我已经忘了原问题的细节,只看了你的基本构件函数,显然忘了c2的广播可以、也将会被提到循环之外. (2认同)
  • @PaulR,关于仿射循环变换我发现这个http://www.sciencedirect.com/science/article/pii/S0010465501002326什么是计算机物理?这听起来像我应该做的,因为我是一个现在更多关于计算机的物理学家. (2认同)
  • @KarlForner,你应该说明你感兴趣的是字节.这会带来很大的不同,特别是因为对于按字节处理的SSE,log2(simd_width)是log2(16)= 4.见[这个答案](http://stackoverflow.com/a/10589306/2542702).事实上,你只看到25%的收益可能还有许多其他原因.我只是在能放进L1缓存的数组上才获得了2.5倍的增益.随着数组增大,您的函数将受限于内存带宽. (2认同)