当多维形状是列表时,Numba 函数中的 NumPy 零不起作用

ind*_*lue 4 python numpy multidimensional-array numba

我尝试尝试使用 numba,因为有人告诉我它非常适用于数值/科学计算应用程序。但是,似乎我已经在以下场景中遇到了问题:

我有一个函数,它计算一个 12x12 的雅可比矩阵,由一个 numpy 数组表示,然后返回这个雅可比矩阵。但是,当我尝试用 装饰所述函数时@numba.njit,出现以下错误:

这通常不是 Numba 本身的问题,而是通常由使用不受支持的功能或解析类型的问题引起的。

作为我使用的一个基本示例,以下代码尝试声明一个 12x12 numpy 零矩阵,但它失败了:

import numpy as np
import numba

@numba.njit
def numpy_matrix_test():
    A = np.zeros([12,12])
    return A

A_out = numpy_matrix_test()
print(A_out)
Run Code Online (Sandbox Code Playgroud)

由于我认为以这种方式声明 numpy 数组很常见以至于 numba 能够处理它们,因此我感到非常惊讶。

MSe*_*ert 8

假设在 numba jitted 函数中调用的函数在未在 numba 函数中使用时是相同的函数实际上是错误的(但可以理解)。实际上,numba(在幕后)委托给它自己的函数,而不是使用“真正的”NumPy 函数。

所以实际上并不是np.zeros在 jitted 函数中调用的,而是它们自己的函数。所以 Numba 和 NumPy 之间的一些差异是不可避免的。

例如,您不能对形状使用列表,它必须是一个元组(列表和数组会产生您遇到的异常)。所以正确的语法是:

@numba.njit
def numpy_matrix_test():
    A = np.zeros((12, 12))
    return A
Run Code Online (Sandbox Code Playgroud)

类似的东西适用于 dtype 参数。它必须是真正的NumPy/numba 类型,不能使用 Python 类型:

@numba.njit
def numpy_matrix_test():
    A = np.zeros((12, 12), dtype=int)  # to make it work use numba.int64 instead of int here
    return A
Run Code Online (Sandbox Code Playgroud)

即使“普通”NumPy 允许:

np.zeros((12, 12), dtype=int)
Run Code Online (Sandbox Code Playgroud)