在 Zen 2 CPU 上使用 AVX2 实现的 GEMM 内核比 AVX2/FMA 更快

Eti*_*e M 6 assembly simd avx micro-optimization matrix-multiplication

我尝试过加快玩具 GEMM 的实施速度。我处理 32x32 双精度块,为此我需要优化的 MM 内核。我可以访问 AVX2 和 FMA。

我在下面定义了两个代码(在 ASM 中,我为格式的粗糙性表示歉意),一个使用 AVX2 功能,另一个使用 FMA。

在不进行微观基准测试的情况下,我想尝试(理论上)理解为什么 AVX2 实现比 FMA 版本快 1.11 倍。以及可能如何改进这两个版本。

下面的代码适用于 3000x3000 双打 MM,并且内核是使用经典的朴素 MM 和可互换的最深循环来实现的。我使用 Ryzen 3700x/Zen 2 作为开发 CPU。

我没有尝试过积极展开,担心 CPU 可能会耗尽物理寄存器。

AVX2 32x32 MM 内核:

Block 82:
    imul r12, r15, 0xbb8
    mov rax, r11
    mov r13d, 0x0
    vmovupd ymm0, ymmword ptr [rdi+r12*8]
    vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
    vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
    vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
    vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
    vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
    vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
    vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
    lea r14, ptr [r12+0x4]
    nop dword ptr [rax+rax*1], eax
Block 83:
    vbroadcastsd ymm8, qword ptr [rcx+r13*8]
    inc r13
    vmulpd ymm10, ymm8, ymmword ptr [rax-0xa0]
    vmulpd ymm11, ymm8, ymmword ptr [rax-0x80]
    vmulpd ymm9, ymm8, ymmword ptr [rax-0xe0]
    vmulpd ymm12, ymm8, ymmword ptr [rax-0xc0]
    vaddpd ymm2, ymm10, ymm2    
    vmulpd ymm10, ymm8, ymmword ptr [rax-0x60]
    vaddpd ymm3, ymm11, ymm3    
    vmulpd ymm11, ymm8, ymmword ptr [rax-0x40]
    vaddpd ymm0, ymm9, ymm0   
    vaddpd ymm1, ymm12, ymm1
    vaddpd ymm4, ymm10, ymm4
    vmulpd ymm10, ymm8, ymmword ptr [rax-0x20]
    vmulpd ymm8, ymm8, ymmword ptr [rax]       
    vaddpd ymm5, ymm11, ymm5    
    add rax, 0x5dc0 
    vaddpd ymm6, ymm10, ymm6
    vaddpd ymm7, ymm8, ymm7 
    cmp r13, 0x20
    jnz 0x140004530 <Block 83>
Block 84:
    inc r15
    add rcx, 0x5dc0
    vmovupd ymmword ptr [rdi+r12*8], ymm0
    vmovupd ymmword ptr [rdi+r14*8], ymm1
    vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
    vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
    vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
    vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
    vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
    vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
    cmp r15, 0x20
    jnz 0x1400044d0 <Block 82>
Run Code Online (Sandbox Code Playgroud)

AVX2/FMA 32x32 MM 内核:

Block 80:
    imul r12, r15, 0xbb8
    mov rax, r11
    mov r13d, 0x0
    vmovupd ymm0, ymmword ptr [rdi+r12*8]
    vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
    vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
    vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
    vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
    vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
    vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
    vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
    lea r14, ptr [r12+0x4]
    nop dword ptr [rax+rax*1], eax
Block 81:
    vbroadcastsd ymm8, qword ptr [rcx+r13*8]
    inc r13
    vfmadd231pd ymm0, ymm8, ymmword ptr [rax-0xe0]
    vfmadd231pd ymm1, ymm8, ymmword ptr [rax-0xc0]
    vfmadd231pd ymm2, ymm8, ymmword ptr [rax-0xa0]
    vfmadd231pd ymm3, ymm8, ymmword ptr [rax-0x80]
    vfmadd231pd ymm4, ymm8, ymmword ptr [rax-0x60]
    vfmadd231pd ymm5, ymm8, ymmword ptr [rax-0x40]
    vfmadd231pd ymm6, ymm8, ymmword ptr [rax-0x20]
    vfmadd231pd ymm7, ymm8, ymmword ptr [rax]
    add rax, 0x5dc0 
    cmp r13, 0x20   
    jnz 0x140004450
Block 82:
    inc r15
    add rcx, 0x5dc0
    vmovupd ymmword ptr [rdi+r12*8], ymm0
    vmovupd ymmword ptr [rdi+r14*8], ymm1
    vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
    vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
    vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
    vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
    vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
    vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
    cmp r15, 0x20
    jnz 0x1400043f0 <Block 80>
Run Code Online (Sandbox Code Playgroud)

Pet*_*des 7

Zen2 对于 具有 3 个周期延迟vaddpd,对于 具有 5 个周期延迟vfma...pd。(https://uops.info/)。

具有 8 个累加器的代码具有足够的 ILP,您预计每个时钟有接近两个 FMA,大约每 5 个时钟 8 个(如果没有其他瓶颈),这比 10/5 理论最大值要小一些。

vaddpd实际上在 Zen2 上的不同vmulpd端口上运行(与 Intel 不同),分别是端口 FP2/3 和 FP0/1,因此理论上它可以支持 2/clock. 由于循环承载依赖的延迟较短,如果调度不让一个 dep 链落后,8 个累加器就足以隐藏延迟。(但至少乘法不会窃取它的周期。)vaddpd vmulpdvaddpd

Zen2 的前端有 5 个指令宽(如果存在多微指令,则为 6 个微指令),并且它可以将内存源指令解码为单个微指令。因此,对于非 FMA 版本,它很可能会每次乘法和加法执行 2 个时钟周期。

如果您可以展开 10 或 12,则可能会隐藏足够的 FMA 延迟并使其与非 FMA 版本相同,但功耗更低,并且对在其他逻辑核心上运行的代码更适合 SMT。(10 = 5 x 2 勉强够用这意味着任何调度缺陷都会在关键路径上的 dep 链上失去进度。请参阅为什么 mulss 在 Haswell 上只需要 3 个周期,与 Agner 的指令表不同?(展开 FP具有多个累加器的循环)用于在英特尔上进行一些测试。)

(相比之下,Intel Skylake 在相同端口上运行 vaddpd/vmulpd,延迟与 vfma...pd 相同,全部延迟为 4c,吞吐量为 0.5c。)

我没有非常仔细地查看你的代码,但是 10 个 YMM 向量可能是接触两对缓存行与接触 5 个总行之间的权衡,如果空间预取器尝试完成对齐对,这可能会更糟。或者可能还好。12 个 YMM 向量将是三对,这应该没问题。

根据矩阵大小,乱序 exec 可能能够在外循环的单独迭代之间重叠内循环 dep 链,特别是如果循环退出条件可以更快地执行并在 FP 工作时解决错误预测(如果有)仍在飞行中。这是相同工作的总 uops 较少的优势,有利于 FMA。

  • @JohnDMcCalpin 在 Zen 2 上,当代码像 OP 的示例一样执行 50% 加法和 50% 乘法时,总限制为 4/时钟。CPU 每个时钟最多可以调度 6 个指令,因此 OP 的数学运算的 4 个/时钟仍然可以满足该特定瓶颈。对于 Intel,无论是加法、乘法还是 fma 指令,它都是 2/时钟。在 Intel 上,FMA 可能是这些周期的最佳用途。 (2认同)