具有复杂 numpy 数组和本机数据类型的 numba TypingError

lea*_*ner 5 python numpy python-3.x numba

我有一个处理复杂数据类型的函数,我正在使用它numba来加快处理速度。我声明了一个零数组numpy,使用复杂数据类型,稍后在函数中填充。但在运行时numba不能使零发生函数过载。为了重现错误,我提供了一个 MWE。

import numpy as np
from numba import njit

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=complex)
    a[idx] = 10
    return a

my_func(4)
Run Code Online (Sandbox Code Playgroud)

以下错误显示在a初始化数组的位置。

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)

No implementation of function Function(<built-in function zeros>) found for signature:
zeros(Tuple(Literal[int](10), Literal[int](5)), dtype=Function(<class 'complex'>))
There are 2 candidate implementations:

 Of which 2 did not match due to:
  Overload of function 'zeros': File: numba\core\typing\npydecl.py: Line 511.
    With argument(s): '(UniTuple(int64 x 2), dtype=Function(<class 'complex'>))':
   No match.
Run Code Online (Sandbox Code Playgroud)

我假设这与变量的数据类型有关a(我需要它很复杂)。我该如何解决这个错误?

任何帮助将不胜感激,谢谢。

小智 1

您的问题与复数无关。如果您指定了a = np.zeros((10, 5), dtype=int),您会遇到同样的问题。

虽然numpy采用 python 本机数据类型int, floatandcomplex将它们视为np.int32, np.float64and np.complex128,但numba它本身并不会这样做。

因此,每当您在 jitted 函数中指定数据类型时,您要么使用numpy数据类型:

import numpy as np
from numba import njit

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=np.complex128)
    a[idx] = 10
    return a

my_func(4)
Run Code Online (Sandbox Code Playgroud)

或者您可以通过直接导入来使用numba数据类型:

import numpy as np
from numba import njit, complex128

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=complex128)
    a[idx] = 10
    return a

my_func(4)
Run Code Online (Sandbox Code Playgroud)

或通过types

import numpy as np
from numba import njit, types

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=types.complex128)
    a[idx] = 10
    return a

my_func(4)
Run Code Online (Sandbox Code Playgroud)

据我所知,您使用这些选项中的哪一个确实没有什么区别。是 numba 文档的相关部分。