fug*_*ede 5 python jit numpy collatz numba
考虑使用以下函数计算 3 n + 1 问题的给定输入的步数:
def num_steps(b, steps):
e = b
d = 0
while True:
if e == 1:
d += steps[e]
return d
if e % 2 == 0:
e //= 2
else:
e = 3*e + 1
d += 1
Run Code Online (Sandbox Code Playgroud)
在这里,steps存在允许对结果进行记忆,但为了这个问题,我们只注意到只要steps[1] == 0,它应该没有影响,因为在这种情况下, 的影响d += steps[e]是将 0 添加到d。事实上,下面的例子给出了预期的结果:
import numpy as np
steps = np.array([0, 0, 0, 0])
print(num_steps(3, steps)) # Prints 7
Run Code Online (Sandbox Code Playgroud)
但是,如果我们使用numba.jit(或njit)对方法进行 JIT 编译,我们将不再得到正确的结果:
import numpy as np
from numba import jit
steps = np.array([0, 0, 0, 0])
print(jit(num_steps)(3, steps)) # Prints 0
Run Code Online (Sandbox Code Playgroud)
如果我们d += steps[e]在编译方法之前删除看似冗余的部分,我们确实会得到正确的结果。我们甚至可以放入一个print(steps[e])befored += steps[e]并查看其值为 0。我也可以将 移动d += 1到循环的顶部(并进行初始化d = -1)以获得在 Numba 情况下也可以使用的东西。
Python 3.8(通过标准 conda 通道提供的最新版本)上的 Numba 0.48.0 (llvmlite 0.31.0) 会发生这种情况。
对我来说,这看起来像是一个错误,与steps[e]. 如果你设置了,parallel=True那就是 Numba 崩溃的地方。您可以在 Numba github 存储库中创建一个问题,也许开发人员可以解释它。
如果我重写该函数以避免最终的就地增量,它对我有用:
@numba.njit
def numb_steps(b, steps):
e = b
d = 0
while True:
if e == 1:
return d + steps[e]
if e % 2 == 0:
e //= 2
else:
e = 3*e + 1
d += 1
Run Code Online (Sandbox Code Playgroud)
和:
python 3.7.6
numba 0.47.0
Run Code Online (Sandbox Code Playgroud)