提高通用交换的性能

Oli*_*ock 5 c generics swap memcpy

语境

在 C 中实现处理一系列类型的泛型函数时,void*经常使用。该libc函数qsort()是一个经典的例子。在内部qsort()和许多其他算法中都需要一个swap()函数。

通用交换的一个简单但典型的实现如下所示:

void swap(void* x, void* y, size_t size) {
    char t[size];
    memcpy(t, x, size);
    memcpy(x, y, size);
    memcpy(y, t, size);
}
Run Code Online (Sandbox Code Playgroud)

对于较大的类型,可以使用逐字节交换,否则malloc会很慢,但这里的重点是当这个泛型swap()用于小类型时会发生什么。

更好的通用交换?

事实证明,如果我们匹配一些常见的类型大小(x86_64 上的 4 和 8 字节的 int 和 long),还包括浮点、双精度、指针等,我们可以获得令人惊讶的性能提升:

void swap(void* x, void* y, size_t size) {
  if (size == sizeof(int)) {
    int t      = *((int*)x);
    *((int*)x) = *((int*)y);
    *((int*)y) = t;
  } else if (size == sizeof(long)) {
    long t      = *((long*)x);
    *((long*)x) = *((long*)y);
    *((long*)y) = t;
  } else {
    char t[size];
    memcpy(t, x, size);
    memcpy(x, y, size);
    memcpy(y, t, size);
  }
}

Run Code Online (Sandbox Code Playgroud)

注意:这显然可以改进为 using#if而不是if/else和用于更多类型。

在以下通用quicksort()实现的上下文中,与更标准的memcpy()仅顶部交换相比,上述交换为 10,000,000 随机 int 排序提供了约 2 倍的性能改进。这是在 ubuntu 20.04 上使用 gcc-9 或 clang-10 -O3

这似乎是一个了不起的结果。

  • 这是否违反任何标准?
  • 任何人都可以验证这一点吗?
  • 是什么让这种收益成为可能?它是简单地复制“更广泛的词”还是一些编译器优化/内联在起作用?
  • 如果它确实有效,为什么还没有完成?或者是吗?

注意:我还没有检查生成的汇编代码 - 还没有。

#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

typedef bool (*cmp)(const void*, const void*);

bool cmp_ints_less(const void* a, const void* b) {
  return *(const int*)a < *(const int*)b;
}

bool cmp_ints_greater(const void* a, const void* b) {
  return *(const int*)a > *(const int*)b;
}

bool cmp_floats_less(const void* a, const void* b) {
  return *(const float*)a < *(const float*)b;
}

bool cmp_floats_greater(const void* a, const void* b) {
  return *(const float*)a > *(const float*)b;
}

bool cmp_doubles_less(const void* a, const void* b) {
  return *(const double*)a < *(const double*)b;
}

bool cmp_doubles_greater(const void* a, const void* b) {
  return *(const double*)a > *(const double*)b;
}

bool cmp_strs_less(const void* a, const void* b) {
  return strcmp(*((const char**)a), *((const char**)b)) < 0;
}

bool cmp_strs_greater(const void* a, const void* b) {
  return strcmp(*((const char**)a), *((const char**)b)) > 0;
}

void swap(void* x, void* y, size_t size) {
  if (size == sizeof(int)) {
    int t      = *((int*)x);
    *((int*)x) = *((int*)y);
    *((int*)y) = t;
  } else if (size == sizeof(long)) {
    long t      = *((long*)x);
    *((long*)x) = *((long*)y);
    *((long*)y) = t;
  } else {
    char t[size];
    memcpy(t, x, size);
    memcpy(x, y, size);
    memcpy(y, t, size);
  }
}

void* partition(void* start, void* end, size_t size, cmp predicate) {
  if (start == NULL || end == NULL || start == end) return start;
  char* storage = (char*)start;
  char* last    = (char*)end - size; // used as pivot
  for (char* current = start; current != last; current += size) {
    if (predicate(current, last)) {
      swap(current, storage, size);
      storage += size;
    }
  }
  swap(storage, last, size);
  return storage; // returns position of pivot
}

void quicksort(void* start, void* end, size_t size, cmp predicate) {
  if (start == end) return;
  void* middle = partition(start, end, size, predicate);
  quicksort(start, middle, size, predicate);
  quicksort((char*)middle + size, end, size, predicate);
}

void print(const int* start, int size) {
  for (int i = 0; i < size; ++i) printf("%3d", start[i]);
  printf("\n");
}

void rand_seed() {
  int   seed = 0;
  FILE* fp   = fopen("/dev/urandom", "re");
  if (!fp) {
    fprintf(stderr, "Warning: couldn't open source of randomness, falling back to time(NULL)");
    srand(time(NULL));
    return;
  }
  if (fread(&seed, sizeof(int), 1, fp) < 1) {
    fprintf(stderr, "Warning: couldn't read random seed, falling back to time(NULL)");
    fclose(fp);
    srand(time(NULL));
    return;
  }
  fclose(fp);
  srand(seed); // nice seed for rand()
}

int rand_range(int start, int end) {
  return start + rand() / (RAND_MAX / (end - start + 1) + 1);
}

int main() {
  // int demo
  rand_seed();
#define int_count 20
  int* ints = malloc(int_count * sizeof(int));
  if (!ints) {
    fprintf(stderr, "couldn't allocate memory");
    exit(EXIT_FAILURE);
  }
  for (int i = 0; i < int_count; ++i) ints[i] = rand_range(1, int_count / 2);
  print(ints, int_count);
  quicksort(ints, ints + int_count, sizeof(int), &cmp_ints_less);
  print(ints, int_count);
  free(ints);

  // string demo
  const char* strings[] = {
      "material", "rare",    "fade",      "aloof",  "way",  "torpid",
      "men",      "purring", "abhorrent", "unpack", "zinc", "unsightly",
  };
  const int str_count = sizeof(strings) / sizeof(strings[0]);
  quicksort(strings, strings + str_count, sizeof(char*), &cmp_strs_greater);
  for (int i = 0; i < str_count; ++i) printf("%s\n", strings[i]);

// double demo
#define dbl_count 20
  double doubles[dbl_count];
  for (int i = 0; i < dbl_count; ++i) doubles[i] = rand() / (RAND_MAX / 100.0);
  quicksort(doubles, doubles + dbl_count, sizeof(char*), &cmp_doubles_less);
  for (int i = 0; i < dbl_count; ++i) printf("%20.16f\n", doubles[i]);

  return EXIT_SUCCESS;
}

Run Code Online (Sandbox Code Playgroud)

编辑:

仅供参考 Compiler Explorer 报告了以下非常明显的替代通用程序集swap()

https://godbolt.org/z/GhvsY4

main()那里的样本是:

int main() {
  int two = 2;
  int three = 3;

  swap(&two, &three, sizeof(int));
  swap2(&two, &three, sizeof(int));

  return two - three;
}
Run Code Online (Sandbox Code Playgroud)

全汇编swap2()下方,但值得注意的是,编译器内联swap2(),但没有 swap()这当然containes的进一步调用memcopy。这可能是一些(全部?)差异?

swap2:
        push    rbp
        mov     rbp, rsp
        push    r14
        mov     r14, rdi
        push    r13
        mov     r13, rsi
        push    r12
        push    rbx
        cmp     rdx, 4
        je      .L9
        mov     r12, rdx
        cmp     rdx, 8
        jne     .L7
        mov     rax, QWORD PTR [rdi]
        mov     rdx, QWORD PTR [rsi]
        mov     QWORD PTR [rdi], rdx
        mov     QWORD PTR [rsi], rax
        lea     rsp, [rbp-32]
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     rbp
        ret
.L7:
        lea     rax, [rdx+15]
        mov     rbx, rsp
        mov     rsi, rdi
        and     rax, -16
        sub     rsp, rax
        mov     rdi, rsp
        call    memcpy
        mov     rdx, r12
        mov     rsi, r13
        mov     rdi, r14
        call    memcpy
        mov     rdx, r12
        mov     rsi, rsp
        mov     rdi, r13
        call    memcpy
        mov     rsp, rbx
        lea     rsp, [rbp-32]
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     rbp
        ret
.L9:
        mov     eax, DWORD PTR [rdi]
        mov     edx, DWORD PTR [rsi]
        mov     DWORD PTR [rdi], edx
        mov     DWORD PTR [rsi], eax
        lea     rsp, [rbp-32]
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     rbp
        ret

Run Code Online (Sandbox Code Playgroud)

And*_*nle 2

这是否违反任何标准?

是的。

这是严格的别名违规,可能违反6.3.2.3 指针,第 7 段:“指向对象类型的指针可以转换为指向不同对象类型的指针。如果生成的指针未针对引用类型正确对齐,则行为未定义。...”