平方和的编译器优化

mer*_*ian 9 c++ optimization assembly rust

这是我觉得有趣的事情:

pub fn sum_of_squares(n: i32) -> i32 {
    let mut sum = 0;
    for i in 1..n+1 {
        sum += i*i;
    }
    sum
}
Run Code Online (Sandbox Code Playgroud)

这是 Rust 中平方和的简单实现。rustc 1.65.0这是带有with的汇编代码-O3

        lea     ecx, [rdi + 1]
        xor     eax, eax
        cmp     ecx, 2
        jl      .LBB0_2
        lea     eax, [rdi - 1]
        lea     ecx, [rdi - 2]
        imul    rcx, rax
        lea     eax, [rdi - 3]
        imul    rax, rcx
        shr     rax
        imul    eax, eax, 1431655766
        shr     rcx
        lea     ecx, [rcx + 4*rcx]
        add     ecx, eax
        lea     eax, [rcx + 4*rdi]
        add     eax, -3
.LBB0_2:
        ret
Run Code Online (Sandbox Code Playgroud)

我原以为它会使用平方和的公式,但事实并非如此。1431655766它使用了一个我根本不理解的神奇数字。

然后我想看看 clang 和 gcc 在 C++ 中对相同的函数做了什么

        test    edi, edi
        jle     .L8
        lea     eax, [rdi-1]
        cmp     eax, 17
        jbe     .L9
        mov     edx, edi
        movdqa  xmm3, XMMWORD PTR .LC0[rip]
        xor     eax, eax
        pxor    xmm1, xmm1
        movdqa  xmm4, XMMWORD PTR .LC1[rip]
        shr     edx, 2
.L4:
        movdqa  xmm0, xmm3
        add     eax, 1
        paddd   xmm3, xmm4
        movdqa  xmm2, xmm0
        pmuludq xmm2, xmm0
        psrlq   xmm0, 32
        pmuludq xmm0, xmm0
        pshufd  xmm2, xmm2, 8
        pshufd  xmm0, xmm0, 8
        punpckldq       xmm2, xmm0
        paddd   xmm1, xmm2
        cmp     eax, edx
        jne     .L4
        movdqa  xmm0, xmm1
        mov     eax, edi
        psrldq  xmm0, 8
        and     eax, -4
        paddd   xmm1, xmm0
        add     eax, 1
        movdqa  xmm0, xmm1
        psrldq  xmm0, 4
        paddd   xmm1, xmm0
        movd    edx, xmm1
        test    dil, 3
        je      .L1
.L7:
        mov     ecx, eax
        imul    ecx, eax
        add     eax, 1
        add     edx, ecx
        cmp     edi, eax
        jge     .L7
.L1:
        mov     eax, edx
        ret
.L8:
        xor     edx, edx
        mov     eax, edx
        ret
.L9:
        mov     eax, 1
        xor     edx, edx
        jmp     .L7
.LC0:
        .long   1
        .long   2
        .long   3
        .long   4
.LC1:
        .long   4
        .long   4
        .long   4
        .long   4
Run Code Online (Sandbox Code Playgroud)

这是gcc 12.2-O3. GCC 也不使用平方和公式。我也不知道为什么它会检查数字是否大于17?但由于某种原因,与 clang 和 rustc 相比,gcc 确实做了很多操作。

这是clang 15.0.0-O3

    test    edi, edi
    jle     .LBB0_1
    lea     eax, [rdi - 1]
    lea     ecx, [rdi - 2]
    imul    rcx, rax
    lea     eax, [rdi - 3]
    imul    rax, rcx
    shr     rax
    imul    eax, eax, 1431655766
    shr     rcx
    lea     ecx, [rcx + 4*rcx]
    add     ecx, eax
    lea     eax, [rcx + 4*rdi]
    add     eax, -3
    ret
.LBB0_1:
        xor     eax, eax
        ret
Run Code Online (Sandbox Code Playgroud)

我不太明白 clang 在那里做了什么样的优化。但 rustc、clang 和 gcc 不喜欢n(n+1)(2n+1)/6

然后我给他们的表演计时。Rust 的表现明显优于 gcc 和 clang。这些是 100 次执行的平均结果。使用第 11 代英特尔酷睿 i7-11800h @ 2.30 GHz

Rust: 0.2 microseconds
Clang: 3 microseconds
gcc: 5 microseconds
Run Code Online (Sandbox Code Playgroud)

有人可以解释一下性能差异吗?

编辑 C++:

int sum_of_squares(int n){
    int sum = 0;
    for(int i = 1; i <= n; i++){
        sum += i*i;
    }
    return sum;
}
Run Code Online (Sandbox Code Playgroud)

编辑2 对于每个想知道这里是我的基准代码的人:

use std::time::Instant;
pub fn sum_of_squares(n: i32) -> i32 {
    let mut sum = 0;
    for i in 1..n+1 {
        sum += i*i;
    }
    sum
}

fn main() {
    let start = Instant::now();
    let result = sum_of_squares(1000);
    let elapsed = start.elapsed();

    println!("Result: {}", result);
    println!("Elapsed time: {:?}", elapsed);
}
Run Code Online (Sandbox Code Playgroud)

在 C++ 中:

#include <chrono>
#include <iostream>

int sum_of_squares(int n){
    int sum = 0;
    for(int i = 1; i <= n; i++){
        sum += i*i;
    }
    return sum;
}

int main() {
    auto start = std::chrono::high_resolution_clock::now();
    int result = sum_of_squares(1000);
    auto end = std::chrono::high_resolution_clock::now();

    std::cout << "Result: " << result << std::endl;
    std::cout << "Elapsed time: "
              << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count()
              << " microseconds" << std::endl;

    return 0;
}
Run Code Online (Sandbox Code Playgroud)

Iwa*_*Iwa 8

我原以为它会使用平方和的公式,但事实并非如此。它使用了一个神奇的数字1431655766,我根本不明白。

LLVM 确实将该循环转换为公式,但它与朴素平方和公式不同。

这篇文章比我更好地解释了公式和生成的代码。


Jér*_*ard 5

Clang 在 C++ 中使用了相同的优化-O3,但尚未在 GCC 中使用。参见GodBolt。AFAIK,默认的 Rust 编译器像 Clang 一样在内部使用 LLVM。这就是为什么他们产生类似的代码。GCC 使用 SIMD 指令矢量化的朴素循环,而 Clang 使用类似于您在问题中给出的公式。

C++代码优化后的汇编代码如下:

sum_of_squares(int):                    # @sum_of_squares(int)
        test    edi, edi
        jle     .LBB0_1
        lea     eax, [rdi - 1]
        lea     ecx, [rdi - 2]
        imul    rcx, rax
        lea     eax, [rdi - 3]
        imul    rax, rcx
        shr     rax
        imul    eax, eax, 1431655766
        shr     rcx
        lea     ecx, [rcx + 4*rcx]
        add     ecx, eax
        lea     eax, [rcx + 4*rdi]
        add     eax, -3
        ret
.LBB0_1:
        xor     eax, eax
        ret
Run Code Online (Sandbox Code Playgroud)

这个优化主要来自于IndVarSimplify优化pass。可以看到,有些变量是用 32 位编码的,而另一些变量是用 33 位编码的(主流平台上需要 64 位寄存器)。代码基本上做了:

if(edi == 0)
    return 0;
eax = rdi - 1;
ecx = rdi - 2;
rcx *= rax;
eax = rdi - 3;
rax *= rcx;
rax >>= 1;
eax *= 1431655766;
rcx >>= 1;
ecx = rcx + 4*rcx;
ecx += eax;
eax = rcx + 4*rdi;
eax -= 3;
return eax;
Run Code Online (Sandbox Code Playgroud)

这可以进一步简化为以下等效的 C++ 代码:

if(n == 0)
    return 0;
int64_t m = n;
int64_t tmp = ((m - 3) * (m - 1) * (m - 2)) / 2;
tmp = int32_t(int32_t(tmp) * int32_t(1431655766));
return 5 * ((m - 1) * (m - 2) / 2) + tmp + (4*m - 3);
Run Code Online (Sandbox Code Playgroud)

请注意,为了清楚起见,忽略了一些强制转换和溢出。

神奇的数字1431655766 来自与除以 3 相关的溢出的一种校正。的确,1431655766 / 2**32 ~= 0.33333333348855376。Clang 利用 32 位溢出来生成公式的快速实现n(n+1)(2n+1)/6


gna*_*729 4

在具有 128 位乘积的机器上除以常数 c 通常是通过乘以 2^64 / c 来实现的。那\xe2\x80\x99s是你奇怪的常量的来源。

\n

现在公式 n(n+1)(2n+1) / 6 对于大的 n 会溢出,而和则不会\xe2\x80\x99t,所以这个公式只能非常非常小心地使用。

\n

  • 不过,代码中没有 128 位乘积(也没有 64 位),它只是一个 32 位乘积(就像将两个 `uint32_t` 彼此相乘),不会生成乘积的“上半部分”。所以它不应该是定点倒数的乘法。但是 1431655766 也有这样的性质: 1431655766 * 3 = 2 (mod 2^32) (3认同)