将某种类型的 0 添加到局部变量时,Numba 抖动会改变结果

fug*_*ede 5 python jit numpy collatz numba

考虑使用以下函数计算 3 n + 1 问题的给定输入的步数:

def num_steps(b, steps):
    e = b
    d = 0
    while True:
        if e == 1:
            d += steps[e]
            return d
        if e % 2 == 0:
            e //= 2
        else:
            e = 3*e + 1
        d += 1
Run Code Online (Sandbox Code Playgroud)

在这里,steps存在允许对结果进行记忆,但为了这个问题,我们只注意到只要steps[1] == 0,它应该没有影响,因为在这种情况下, 的影响d += steps[e]是将 0 添加到d。事实上,下面的例子给出了预期的结果:

import numpy as np

steps = np.array([0, 0, 0, 0])
print(num_steps(3, steps))  # Prints 7
Run Code Online (Sandbox Code Playgroud)

但是,如果我们使用numba.jit(或njit)对方法进行 JIT 编译,我们将不再得到正确的结果:

import numpy as np
from numba import jit

steps = np.array([0, 0, 0, 0])
print(jit(num_steps)(3, steps))  # Prints 0
Run Code Online (Sandbox Code Playgroud)

如果我们d += steps[e]在编译方法之前删除看似冗余的部分,我们确实会得到正确的结果。我们甚至可以放入一个print(steps[e])befored += steps[e]并查看其值为 0。我也可以将 移动d += 1到循环的顶部(并进行初始化d = -1)以获得在 Numba 情况下也可以使用的东西。

Python 3.8(通过标准 conda 通道提供的最新版本)上的 Numba 0.48.0 (llvmlite 0.31.0) 会发生这种情况。

Rut*_*ies 2

对我来说,这看起来像是一个错误,与steps[e]. 如果你设置了,parallel=True那就是 Numba 崩溃的地方。您可以在 Numba github 存储库中创建一个问题,也许开发人员可以解释它。

如果我重写该函数以避免最终的就地增量,它对我有用:

@numba.njit
def numb_steps(b, steps):

    e = b    
    d = 0

    while True:

        if e == 1:
            return d + steps[e]

        if e % 2 == 0:
            e //= 2
        else:
            e = 3*e + 1

        d += 1
Run Code Online (Sandbox Code Playgroud)

和:

python                    3.7.6
numba                     0.47.0
Run Code Online (Sandbox Code Playgroud)