如何优化 NumPy 埃拉托色尼筛?

Ξέν*_*νος 0 python primes numpy sieve-of-eratosthenes python-3.x

我在 NumPy 中实现了自己的埃拉托斯特尼筛法。我相信你们都知道它是为了找到一个数字以下的所有素数,所以我不会进一步解释。

\n

代码:

\n
import numpy as np\n\ndef primes_sieve(n):\n    primes = np.ones(n+1, dtype=bool)\n    primes[:2] = False\n    primes[4::2] = False\n    for i in range(3, int(n**0.5)+1, 2):\n        if primes[i]:\n            primes[i*i::i] = False\n\n    return np.where(primes)[0]\n
Run Code Online (Sandbox Code Playgroud)\n

正如你所看到的,我已经做了一些优化,首先除了 2 之外所有素数都是奇数,所以我将 2 的所有倍数设置为False且仅是暴力奇数。

\n

其次,我只循环遍历直到平方根下限的数字,因为平方根之后的所有合数都会因平方根以下素数的倍数而被消除。

\n

但这不是最佳的,因为它循环遍历低于限制的所有奇数,并且并非所有奇数都是质数。随着数字的增大,素数变得更加稀疏,因此存在大量冗余迭代。

\n

因此,如果候选列表是动态更改的,以这样的方式,已经识别的合数甚至不会被迭代,因此只有质数被循环,不会有任何浪费的迭代,因此算法将是最优的。

\n

我写了一个优化版本的粗略实现:

\n
def primes_sieve_opt(n):\n    primes = np.ones(n+1, dtype=bool)\n    primes[:2] = False\n    primes[4::2] = False\n    limit = int(n**0.5)+1\n    i = 2\n    while i < limit:\n        primes[i*i::i] = False\n        i += 1 + primes[i+1:].argmax()\n\n    return np.where(primes)[0]\n
Run Code Online (Sandbox Code Playgroud)\n

但它比未优化的版本慢得多:

\n
In [92]: %timeit primes_sieve(65536)\n271 \xc2\xb5s \xc2\xb1 22 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000 loops each)\n\nIn [102]: %timeit primes_sieve_opt(65536)\n309 \xc2\xb5s \xc2\xb1 3.86 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 1,000 loops each)\n
Run Code Online (Sandbox Code Playgroud)\n

我的想法很简单,通过获取 的下一个索引True,我可以确保覆盖所有素数并且仅处理素数。

\n

np.argmax在这方面进展缓慢。我在 Google 上搜索了“如何在 NumPy 数组中查找下一个 True 值的索引”(不带引号),我发现了几个 StackOverflow 问题,这些问题稍微相关,但最终没有回答我的问题。

\n

例如,numpy 获取 value 为 true 的索引,并且Numpy 第一次出现 value 大于现有 value

\n

我并不是试图找到 的所有索引True,这样做非常愚蠢,我需要找到下一个True值,获取其索引并立即停止循环,只有bools。

\n

我该如何优化这个?

\n
\n

编辑

\n

如果有人感兴趣,我进一步优化了我的算法:

\n
import numba\nimport numpy as np\n\n@numba.jit(nopython=True, parallel=True, fastmath=True, forceobj=False)\ndef prime_sieve(n: int) -> np.ndarray:\n    primes = np.full(n + 1, True)\n    primes[:2] = False\n    primes[4::2] = False\n    primes[9::6] = False\n    limit = int(n**0.5) + 1\n    for i in range(5, limit, 6):\n        if primes[i]:\n            primes[i * i :: 2 * i] = False\n\n    for i in range(7, limit, 6):\n        if primes[i]:\n            primes[i * i :: 2 * i] = False\n\n    return np.flatnonzero(primes)\n
Run Code Online (Sandbox Code Playgroud)\n

我过去常常numba加快速度。由于除了 2 和 3 之外的所有素数都是 6k+1 或 6k-1,这使得速度更快。

\n

Nic*_*ell 5

我的想法很简单,通过获取 True 的下一个索引,我可以确保覆盖所有素数并且仅处理素数。

一些分析表明,通过这种方式最多可以获得 0.2% 的加速。

对于较大的 N 值,绝大多数时间都花在了 上primes[i*i::i] = False

以下是 line_profiler 在前一亿个素数上运行的输出:

Timer unit: 1e-09 s

Total time: 1.04878 s
File: /tmp/ipykernel_22262/2557137730.py
Function: primes_sieve at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def primes_sieve(n):
     4         1   14264754.0 14264754.0      1.4      primes = np.ones(n+1, dtype=bool)
     5         1      12394.0  12394.0      0.0      primes[:2] = False
     6         1   16238905.0 16238905.0      1.5      primes[4::2] = False
     7      4999    1309955.0    262.0      0.1      for i in range(3, int(n**0.5)+1, 2):
     8      3771    1507909.0    399.9      0.1          if primes[i]:
     9      1228  909007228.0 740233.9     86.7              primes[i*i::i] = False
    10                                           
    11         1  106434647.0 106434647.0     10.1      return np.where(primes)[0]
Run Code Online (Sandbox Code Playgroud)

如果您跳过更多 的值i,则可以避免在for i in range(3, int(n**0.5)+1, 2):和行上花费时间if primes[i]:。但你无法避免花费在 的时间primes[i*i::i] = False。由于程序在每一项上花费 0.1%,因此您最多可以节省 0.2% 的执行时间。