为什么 np.dot 比 np.sum 快这么多?

Rap*_*ael 47 python numpy simd cython numba

为什么 np.dot 比 np.sum 快这么多?根据这个答案,我们知道 np.sum 很慢并且有更快的替代方案。

\n

例如:

\n
In [20]: A = np.random.rand(1000)\n\nIn [21]: B = np.random.rand(1000)\n\nIn [22]: %timeit np.sum(A)\n3.21 \xc2\xb5s \xc2\xb1 270 ns per loop (mean \xc2\xb1 std. dev. of 7 runs, 100,000 loops each)\n\nIn [23]: %timeit A.sum()\n1.7 \xc2\xb5s \xc2\xb1 11.5 ns per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000,000 loops each)\n\nIn [24]: %timeit np.add.reduce(A)\n1.61 \xc2\xb5s \xc2\xb1 19.6 ns per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000,000 loops each)\n
Run Code Online (Sandbox Code Playgroud)\n

但它们都比以下慢:

\n
In [25]: %timeit np.dot(A,B)\n1.18 \xc2\xb5s \xc2\xb1 43.9 ns per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000,000 loops each)\n
Run Code Online (Sandbox Code Playgroud)\n

假设 np.dot 都是将两个数组按元素相乘,然后对它们求和,这怎么可能比仅对一个数组求和更快呢?如果 B 设置为全 1 数组,则 np.dot 将简单地对 A 求和。

\n

因此,对 A 求和的最快选项似乎是:

\n
In [26]: O = np.ones(1000)\nIn [27]: %timeit np.dot(A,O)\n1.16 \xc2\xb5s \xc2\xb1 6.37 ns per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000,000 loops each)\n
Run Code Online (Sandbox Code Playgroud)\n

这不对吧?

\n

这是在 Ubuntu 上使用 numpy 1.24.2,在 Python 3.10.6 上使用 openblas64。

\n

此 NumPy 安装中支持的 SIMD 扩展:

\n
baseline = SSE,SSE2,SSE3\nfound = SSSE3,SSE41,POPCNT,SSE42,AVX,F16C,FMA3,AVX2\n
Run Code Online (Sandbox Code Playgroud)\n

更新

\n

如果数组更长,则时间顺序会颠倒。那是:

\n
In [28]: A = np.random.rand(1000000)\nIn [29]: O = np.ones(1000000)\nIn [30]: %timeit np.dot(A,O)\n545 \xc2\xb5s \xc2\xb1 8.87 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000 loops each)\nIn [31]: %timeit np.sum(A)\n429 \xc2\xb5s \xc2\xb1 11 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000 loops each)    \nIn [32]: %timeit A.sum()\n404 \xc2\xb5s \xc2\xb1 2.95 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000 loops each)\nIn [33]: %timeit np.add.reduce(A)\n401 \xc2\xb5s \xc2\xb1 4.21 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000 loops each)\n
Run Code Online (Sandbox Code Playgroud)\n

这对我来说意味着调用 np.sum(A)、A.sum()、np.add.reduce(A) 时存在一些固定大小的开销,而调用 np.dot() 时不存在这些开销,但该部分进行求和的代码实际上更快。

\n

\xe2\x80\x94\xe2\x80\x94\xe2\x80\x94\xe2\x80\x94\xe2\x80\x94\xe2\x80\x94\xe2\x80\x94\xe2\x80\x94\xe2 \x80\x94\xe2\x80\x94-

\n

使用 cython、numba、python 等的任何加速都会很高兴看到。

\n

use*_*ica 40

numpy.dot此处委托给BLAS 向量-向量乘法,同时numpy.sum使用成对求和例程,切换到块大小为 128 个元素的 8x 展开求和循环。

我不知道您的 NumPy 使用的是什么 BLAS 库,但据我所知,好的 BLAS 通常会利用 SIMD 操作,但numpy.sum不会这样做。代码中任何 SIMD 的使用numpy.sum都必须是编译器自动向量化,这可能比 BLAS 的效率低。

当您将数组大小增加到 100 万个元素时,此时您可能会达到缓存阈值。该dot代码正在处理大约 16 MB 的数据,该sum代码正在处理大约 8 MB 的数据。该dot代码可能会将数据传输到较慢的缓存级别或 RAM,或者可能两者dotsum使用较慢的缓存级别,并且dot性能较差,因为它需要读取更多数据。sum如果我尝试逐渐增加数组大小,则与具有更高的每个元素性能相比,计时与某种阈值效应更加一致。


Jér*_*ard 33

这个答案通过提供额外的细节来完善@user2357112的好答案。这两个功能都得到了优化。话虽这么说,成对求和通常会慢一些,但通常会提供更准确的结果。尽管相对较好,但它也不是最理想的。Windows 上默认使用的 OpenBLAS 不执行成对求和。

\n

以下是 Numpy 代码的汇编代码:

\n

在此输入图像描述

\n

以下是 OpenBLAS 代码的汇编代码:

\n

在此输入图像描述

\n

Numpy 代码的主要问题是它不使用 AVX(256 位 SIMD 指令集),而是使用 SSE(128 位 SIMD 指令集),而不是 OpenBLAS,至少在 1.22.4 版本(我使用的那个) )和之前。更糟糕的是:Numpy 代码中的指令是标量指令!我们最近致力于此工作,最新版本的 Numpy 现在应该使用 AVX。话虽如此,由于成对求和(特别是对于大数组),它可能仍然不如 OpenBLAS 快。

\n

请注意,由于数组太小,这两个函数在开销上花费了不可忽略的时间。可以使用 Numba 中的手写实现来消除此类开销。

\n
\n

如果数组更长,则时间顺序会颠倒。

\n
\n

这是预料之中的。事实上,当函数在缓存中运行时,它们会受到计算限制,但当数组很大并且适合 L3 缓存甚至 RAM 时,它们就会受到内存限制。因此,np.dot对于更大的数组,统计数据会更慢,因为它需要从内存中读取两倍大的数据。更具体地说,它需要从内存中读取8*1000000*2/1024**2 ~= 15.3 MiB,因此您可能需要从 RAM 中读取数据,而 RAM 的吞吐量非常有限。事实上,像我这样的良好双通道 3200 MHz DDR4 RAM 可以达到接近 40 GiB 的实际吞吐量和15.3/(40*1024) ~= 374 \xc2\xb5s. 话虽这么说,顺序代码很难完全饱和此吞吐量,因此顺序达到 30 GiB/s 已经很棒了,更不用说许多主流 PC RAM 在较低频率下运行。30 GHz/s 的吞吐量导致 ~500 \xc2\xb5s,这接近您的计时。同时,由于实现效率低,因此np.sumnp.add.reduce更受计算限制,但要读取的数据量要小两倍,并且实际上可能更适合具有更大吞吐量的 L3 缓存。

\n

为了证明这种效果,你可以简单地尝试运行:

\n
# L3 cache of 9 MiB\n\n# 2 x 22.9 = 45.8 MiB\na = np.ones(3_000_000)\nb = np.ones(3_000_000)\n%timeit -n 100 np.dot(a, a)   #  494 \xc2\xb5s => read from RAM\n%timeit -n 100 np.dot(a, b)   # 1007 \xc2\xb5s => read from RAM\n\n# 2 x 7.6 = 15.2 MiB\na = np.ones(1_000_000)\nb = np.ones(1_000_000)\n%timeit -n 100 np.dot(a, a)   #  90 \xc2\xb5s => read from the L3 cache\n%timeit -n 100 np.dot(a, b)   # 283 \xc2\xb5s => read from RAM\n\n# 2 x 1.9 = 3.8 MiB\na = np.ones(250_000)\nb = np.ones(250_000)\n%timeit -n 100 np.dot(a, a)   # 40 \xc2\xb5s => read from the L3 cache (quite compute-bound)\n%timeit -n 100 np.dot(a, b)   # 46 \xc2\xb5s => read from the L3 cache too (quite memory-bound)\n
Run Code Online (Sandbox Code Playgroud)\n

在我的机器上,L3 的大小只有 9 MiB,因此第二次调用不仅需要读取两倍的数据,而且还需要从较慢的 RAM 中读取比从 L3 缓存中读取的数据更多的数据。

\n

对于小型阵列,L1 缓存的速度非常快,读取数据不应该成为瓶颈。在我的 i5-9600KF 机器上,L1 缓存的吞吐量非常大:~268 GiB/s。这意味着读取两个大小为 1000 的数组的最佳时间是8*1000*2/(268*1024**3) ~= 0.056 \xc2\xb5s。实际上,调用 Numpy 函数的开销远大于此。

\n
\n

快速实施

\n

这是一个快速的 Numba 实现

\n
import numba as nb\n\n# Function eagerly compiled only for 64-bit contiguous arrays\n@nb.njit(\'float64(float64[::1],)\', fastmath=True)\ndef fast_sum(arr):\n    s = 0.0\n    for i in range(arr.size):\n        s += arr[i]\n    return s\n
Run Code Online (Sandbox Code Playgroud)\n

以下是性能结果:

\n
 array items |    time    |  speedup (dot/numba_seq)\n--------------------------|------------------------\n 3_000_000   |   870 \xc2\xb5s   |   x0.57\n 1_000_000   |   183 \xc2\xb5s   |   x0.49\n   250_000   |    29 \xc2\xb5s   |   x1.38\n
Run Code Online (Sandbox Code Playgroud)\n

如果您使用标志parallel=Trueandnb.prange而不是range,Numba 将使用多个线程。这对于大型数组很有用,但对于某些机器上的小型数组可能不适用(由于创建线程和共享工作的开销):

\n
 array items |    time    |  speedup (dot/numba_par)\n--------------------------|--------------------------\n 3_000_000   |   465 \xc2\xb5s   |   x1.06\n 1_000_000   |    66 \xc2\xb5s   |   x1.36\n   250_000   |    10 \xc2\xb5s   |   x4.00\n
Run Code Online (Sandbox Code Playgroud)\n

正如预期的那样,Numba 对于小型数组来说速度更快(因为 Numpy 调用开销大​​部分被消除),并且对于大型数组来说可以与 OpenBLAS 竞争。Numba 生成的代码非常高效:

\n
.LBB0_7:\n        vaddpd  (%r9,%rdx,8), %ymm0, %ymm0\n        vaddpd  32(%r9,%rdx,8), %ymm1, %ymm1\n        vaddpd  64(%r9,%rdx,8), %ymm2, %ymm2\n        vaddpd  96(%r9,%rdx,8), %ymm3, %ymm3\n        vaddpd  128(%r9,%rdx,8), %ymm0, %ymm0\n        vaddpd  160(%r9,%rdx,8), %ymm1, %ymm1\n        vaddpd  192(%r9,%rdx,8), %ymm2, %ymm2\n        vaddpd  224(%r9,%rdx,8), %ymm3, %ymm3\n        vaddpd  256(%r9,%rdx,8), %ymm0, %ymm0\n        vaddpd  288(%r9,%rdx,8), %ymm1, %ymm1\n        vaddpd  320(%r9,%rdx,8), %ymm2, %ymm2\n        vaddpd  352(%r9,%rdx,8), %ymm3, %ymm3\n        vaddpd  384(%r9,%rdx,8), %ymm0, %ymm0\n        vaddpd  416(%r9,%rdx,8), %ymm1, %ymm1\n        vaddpd  448(%r9,%rdx,8), %ymm2, %ymm2\n        vaddpd  480(%r9,%rdx,8), %ymm3, %ymm3\n        addq    $64, %rdx\n        addq    $-4, %r11\n        jne     .LBB0_7\n
Run Code Online (Sandbox Code Playgroud)\n

话虽如此,它并不是最佳的:LLVM-Lite JIT 编译器使用 4x 展开,而 8x 展开在我的 Intel CoffeeLake 处理器上应该是最佳的。事实上,指令的延迟vaddpd是 4 个周期,而每个周期可以执行 2 条指令,因此需要 8 个寄存器来避免停顿以及生成的代码受延迟限制。此外,该汇编代码在 Intel Alderlake 和 Sapphire Rapids 处理器上是最佳的,因为它们的延迟降低了一倍vaddpd。让 FMA SIMD 处理单元饱和绝非易事。我认为编写更快的函数的唯一方法是使用 SIMD 内在函数编写 (C/C++) 本机代码,尽管它的可移植性较差。

\n

请注意,Numba 代码不支持 NaN 或 Inf 值等特殊数字(据fastmath我所知 OpenBLAS 支持)。实际上,它应该仍然可以在 x86-64 机器上工作,但这并不能保证。此外,Numba 代码对于非常大的数组来说在数值上不稳定。Numpy 代码应该是这三个变体中数值最稳定的(然后是 OpenBLAS 代码)。您可以按块计算总和以提高数值稳定性,但这会使代码更加复杂。天下没有免费的午餐。

\n