删除 Rust 循环中的边界检查以尝试达到最佳编译器输出

mcm*_*yer 10 loops rust bounds-check-elimination

为了确定我是否可以/应该使用 Rust 而不是默认的 C/C++,我正在研究各种边缘情况,主要考虑到这个问题:在 0.1% 的情况下,它确实很重要,我总能得到编译器输出与 gcc 一样好(具有适当的优化标志)?答案很可能是否定的,但让我们看看......

Reddit上有一个相当特殊的示例,研究无分支排序算法的子例程的编译器输出。

这是基准 C 代码:

#include <stdint.h>
#include <stdlib.h>
int32_t* foo(int32_t* elements, int32_t* buffer, int32_t pivot)
{
    size_t buffer_index = 0;

    for (size_t i = 0; i < 64; ++i) {
        buffer[buffer_index] = (int32_t)i;
        buffer_index += (size_t)(elements[i] < pivot);
    }
}
Run Code Online (Sandbox Code Playgroud)

这是带有编译器输出的godbolt 链接

Rust 的第一次尝试如下所示:

pub fn foo0(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        buffer[buffer_index] = i as i32;
        buffer_index += (elements[i] < pivot) as usize; 
    }
}
Run Code Online (Sandbox Code Playgroud)

正在进行相当多的边界检查,请参阅godbolt

下一次尝试消除第一次边界检查:

pub unsafe fn foo1(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        unsafe {
            buffer[buffer_index] = i as i32;
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

这样好一点(请参阅与上面相同的 godbolt 链接)。

最后,让我们尝试完全删除边界检查:

use std::ptr;

pub unsafe fn foo2(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    unsafe {
        for i in 0..buffer.len() {
            ptr::replace(&mut buffer[buffer_index], i as i32);
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

这会产生与 相同的输出foo1,因此ptr::replace仍然执行边界检查。对于这些unsafe操作,我确实超出了我的能力范围。这引出了我的两个问题:

  • 如何消除边界检查?
  • 分析这样的边缘情况是否有意义?或者,如果提供整个算法而不是其中的一小部分,Rust 编译器会看穿这一切吗?

关于最后一点,我很好奇,总的来说,Rust 是否可以被屠宰到“字面上的”程度,即接近金属,就像 C 一样。经验丰富的 Rust 程序员可能会对这种调查感到畏缩,但事实是……

Ang*_*ros 2

您可以使用老式指针算术来实现这一点。

const N: usize = 64;
pub fn foo2(elements: &Vec<i32>, mut buffer: [i32; N], pivot: i32) -> () {
    assert!(elements.len() >= N);
    let elements = &elements[..N];
    let mut buff_ptr = buffer.as_mut_ptr();
    for (i, &elem) in elements.iter().enumerate(){
        unsafe{
            // SAFETY: We increase ptr strictly less or N times
            *buff_ptr = i as i32;
            if elem < pivot{
                buff_ptr = buff_ptr.add(1);
            }
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

该版本编译为:

example::foo2:
        push    rax
        cmp     qword ptr [rdi + 16], 64
        jb      .LBB7_4
        mov     r9, qword ptr [rdi]
        lea     r8, [r9 + 256]
        xor     edi, edi

        // Loop goes here
.LBB7_2:
        mov     ecx, dword ptr [r9 + 4*rdi]
        mov     dword ptr [rsi], edi
        lea     rax, [rsi + 4]
        cmp     ecx, edx
        cmovge  rax, rsi
        mov     ecx, dword ptr [r9 + 4*rdi + 4]
        lea     esi, [rdi + 1]
        mov     dword ptr [rax], esi
        lea     rsi, [rax + 4]
        cmp     ecx, edx
        cmovge  rsi, rax
        mov     eax, dword ptr [r9 + 4*rdi + 8]
        lea     ecx, [rdi + 2]
        mov     dword ptr [rsi], ecx
        lea     rcx, [rsi + 4]
        cmp     eax, edx
        cmovge  rcx, rsi
        mov     r10d, dword ptr [r9 + 4*rdi + 12]
        lea     esi, [rdi + 3]
        lea     rax, [r9 + 4*rdi + 16]
        add     rdi, 4
        mov     dword ptr [rcx], esi
        lea     rsi, [rcx + 4]
        cmp     r10d, edx
        cmovge  rsi, rcx
        // Conditional branch to the loop beginning
        cmp     rax, r8
        jne     .LBB7_2
        pop     rax
        ret
.LBB7_4:
        call    std::panicking::begin_panic
        ud2
Run Code Online (Sandbox Code Playgroud)

如您所见,循环展开,单分支是循环迭代跳转。

然而,令我惊讶的是,这个函数并没有被消除,因为它没有任何效果:它应该被编译成简单的 noop。内联后可能会变成这样。

另外,我想说,更改 &mut 的参数不会更改代码:

example::foo2:
        push    rax
        cmp     qword ptr [rdi + 16], 64
        jb      .LBB7_4
        mov     r9, qword ptr [rdi]
        lea     r8, [r9 + 256]
        xor     edi, edi
.LBB7_2:
        mov     ecx, dword ptr [r9 + 4*rdi]
        mov     dword ptr [rsi], edi
        lea     rax, [rsi + 4]
        cmp     ecx, edx
        cmovge  rax, rsi
        mov     ecx, dword ptr [r9 + 4*rdi + 4]
        lea     esi, [rdi + 1]
        mov     dword ptr [rax], esi
        lea     rsi, [rax + 4]
        cmp     ecx, edx
        cmovge  rsi, rax
        mov     eax, dword ptr [r9 + 4*rdi + 8]
        lea     ecx, [rdi + 2]
        mov     dword ptr [rsi], ecx
        lea     rcx, [rsi + 4]
        cmp     eax, edx
        cmovge  rcx, rsi
        mov     r10d, dword ptr [r9 + 4*rdi + 12]
        lea     esi, [rdi + 3]
        lea     rax, [r9 + 4*rdi + 16]
        add     rdi, 4
        mov     dword ptr [rcx], esi
        lea     rsi, [rcx + 4]
        cmp     r10d, edx
        cmovge  rsi, rcx
        cmp     rax, r8
        jne     .LBB7_2
        pop     rax
        ret
.LBB7_4:
        call    std::panicking::begin_panic
        ud2
Run Code Online (Sandbox Code Playgroud)

因此,不幸的是,rustc 可能会发出该函数接受缓冲区参数作为 LLVM IR 中的指针。