Numba @jit无法优化简单的功能

Kar*_*Man 2 python numpy numba

我有一个非常简单的函数,它使用Numpy数组和for循环,但添加Numba @jit装饰器绝对没有加速:

# @jit(float64[:](int32,float64,float64,float64,int32))
@jit
def Ising_model_1D(N=200,J=1,T=1e-2,H=0,n_iter=1e6):
    beta = 1/T
    s = randn(N,1) > 10  
    s[N-1] = s[0]
    mag = zeros((n_iter,1))
    aux_idx =  randint(low=0,high=N,size=(n_iter,1))

    for i1 in arange(n_iter):
        rnd_idx = aux_idx[i1]
        s_1 = s[rnd_idx]*2 - 1
        s_2 = s[(rnd_idx+1)%(N)]*2 - 1
        s_3 = s[(rnd_idx-1)%(N)]*2 - 1
        delta_E = 2.0*J*(s_2+s_3)*s_1 + 2.0*H*s_1

        if(delta_E < 0):
            s[rnd_idx] = np.logical_not(s[rnd_idx]) 
        elif(np.exp(-1*beta*delta_E) >= rand()):
            s[rnd_idx] = np.logical_not(s[rnd_idx])
        s[N-1] = s[0]
        mag[i1] = (s*2-1).sum()*1.0/N 
    return mag
Run Code Online (Sandbox Code Playgroud)

另一方面,MATLAB运行时间不到0.5秒!为什么Numba遗漏了这么基本的东西?

jme*_*jme 8

这是在我的机器上运行大约0.4秒的代码的重新处理:

def ising_model_1d(N=200,J=1,T=1e-2,H=0,n_iter=1e6):
    n_iter = int(n_iter)
    beta = 1/T
    s = randn(N) > 10
    s[N-1] = s[0]

    mag = zeros(n_iter)
    aux_idx =  randint(low=0,high=N,size=n_iter)

    pre_rand = rand(n_iter)

    _ising_jitted(n_iter, aux_idx, s, J, N, H, beta, pre_rand, mag)

    return mag


@jit(nopython=True)
def _ising_jitted(n_iter, aux_idx, s, J, N, H, beta, pre_rand, mag):
    for i1 in range(n_iter):
        rnd_idx = aux_idx[i1]
        s_1 = s[rnd_idx*2] - 1
        s_2 = s[(rnd_idx+1)%(N)]*2 - 1
        s_3 = s[(rnd_idx-1)%(N)]*2 - 1
        delta_E = 2.0*J*(s_2+s_3)*s_1 + 2.0*H*s_1
        t = rand()
        if delta_E < 0:
            s[rnd_idx] = not s[rnd_idx]
        elif np.exp(-1*beta*delta_E) >= pre_rand[i1]:
            s[rnd_idx] = not s[rnd_idx]

        s[N-1] = s[0]
        mag[i1] = (s*2-1).sum()*1.0/N
Run Code Online (Sandbox Code Playgroud)

请确保结果符合预期!我改变了很多你所拥有的,并不能保证计算是正确的!

工作numba需要一点点小心.Python函数以及大多数numpy函数都无法通过编译器进行优化.我觉得有用的一件事是使用nopython选项@jit.这意味着只要你给它一些无法真正优化的代码,编译器就会抱怨.然后,您可以查看错误消息并找到可能会降低代码速度的行.

我发现,技巧是在Python中编写一个"网关"函数,尽可能多地使用numpy及其矢量化函数.它应该创建你需要存储结果的空数组.它应该打包你在计算过程中需要的所有数据.然后它应该将所有这些传递到一个大而长的参数列表中的jitted函数中.

例证:注意我如何处理jitted代码中的随机数生成.在原始代码中,您调用了rand():

elif(np.exp(-1*beta*delta_E) >= rand()):
Run Code Online (Sandbox Code Playgroud)

但是rand()无法优化numba(numba至少在旧版本中.在新版本中它可以,只要在rand没有参数的情况下调用).观察结果是每个n_iter迭代都需要一个随机数.所以我们只需numpy在包装函数中创建一个随机数组,然后将这个随机数组提供给jitted函数.获取随机数就像索引到此数组一样简单.

最后,有关numpy可由最新版本编译器优化的函数列表,请参见此处.在我重新编写代码时,我积极地删除对numpy函数的调用,以便代码可以在更多版本的函数上运行numba.