为什么 numba 无法编译这个 for 循环?

Unl*_*kus 5 python numba python-3.6

我尝试计算一维离散余弦变换(类型 2),我试图用 numba 提高我的性能。我有以下代码:

import numpy as np
import math
import numba

@numba.jit()
def epsilon(N:int, i: int) -> float:
    if i == 0 or i == N:
        return math.sqrt(2)/2
    return 1.0

@numba.jit()
def dct2(a):
    n = len(a)
    y = np.empty([2*n])
    y[:len(a)] = a
    y[n:] = np.flip(a)
    fft = np.fft.fft(y)
    erg = np.empty([n])
    factor = 1/math.sqrt(2*n)
    for i in range(0,n):
        erg[i] = factor*epsilon(n,i)*(math.cos(-i*2*math.pi/(4*n))*fft[i].real - math.sin(-i*2*math.pi/(4*n))*fft[i].imag)
    return erg
Run Code Online (Sandbox Code Playgroud)

我认为它无法编译for循环,但我不知道为什么。根据我对 numba 文档的理解,应该能够解除循环。

我收到以下警告:

In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function empty>)
[2] During: typing of call at src/algos.py (32)


File "src/algos.py", line 32:
def dct2(a):
    <source elided>
    n = len(a)
    y = np.empty([2*n])
 ^

  @numba.jit()
src/algos.py:29: NumbaWarning: Function "dct2" failed type inference: cannot determine Numba type of <class 'numba.dispatcher.LiftedLoop'>

File "src/algos.py", line 39:
def dct2(a):
    <source elided>
    factor = 1/math.sqrt(2*n)
    for i in range(0,n):
 ^

  @numba.jit()
src/algos.py:29: NumbaWarning: Function "dct2" was compiled in object mode without forceobj=True, but has lifted loops.
  @numba.jit()
src/algos.py:29: NumbaWarning: Function "dct2" failed type inference: Invalid use of Function(<built-in function empty>) with argument(s) of type(s): (list(int64))
 * parameterized
Run Code Online (Sandbox Code Playgroud)

有谁知道循环失败的原因以及如何修复它?

Koc*_*cas 1

首先,正如@Eskapp指出的,numba不支持一些numpy方法,其中显然包括fft。其次,numba 对数据类型非常挑剔(因为它被编译成静态类型代码),因此您应该具体到最大值。

我会避免在不提供类型的情况下启动数组,并执行以下操作(省略 fft 部分):

def dct2(a):
    n = len(a)
    y = np.empty(2*n, np.float64)
    y[:len(a)] = a
    y[n:] = np.flip(a)
    # fft = np.fft.fft(y)
    erg = np.empty(n, np.float64)
    factor = 1 / math.sqrt(2*n)
    for i in range(n):
        erg[i] = factor*epsilon(n,i)*(math.cos(-i*2*math.pi/(4*n)) - math.sin(-i*2*math.pi/(4*n)))
    return erg
Run Code Online (Sandbox Code Playgroud)

我希望它能够编译并运行。

另请注意,numba 支持的任何 numpy 方法在 numba 端都有特定的实现 - 包括被调用者类型。因此,使用“整数列表”(如 [2 * n] 而不是 2 * n)的参数调用 np.empty() ,将会引发一个非常不言自明的键入错误,如下所示,因为该签名未实现:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)No implementation of function Function(<built-in function empty>) found for signature: 
>>>empty(list(int64)<iv=None>, class(float64))
Run Code Online (Sandbox Code Playgroud)