Numba:使用具有默认值的参数调用带有显式签名的 jit

gio*_*nni 7 python jit types numba

我正在使用 numba 来制作一些包含 numpy 数组循环的函数。

一切都很好,花花公子,我可以使用jit并且我学会了如何定义签名。

现在我尝试在带有可选参数的函数上使用 jit,例如:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float))'])
def fun(a, b=3):
    return a + b
Run Code Online (Sandbox Code Playgroud)

这有效,但如果不是optional(float)我使用optional(float64)它,则无效(与int或相同int64)。我花了 1 个小时试图弄清楚这个语法(实际上,我的一个朋友偶然发现了这个解决方案,因为他忘记64在浮动之后写),但是,为了我的爱,我不明白为什么会这样。我在互联网上找不到任何东西,而且 numba 关于该主题的文档充其量是稀缺的(并且他们指定optional应该采用 numba 类型)。

有谁知道这是如何工作的?我错过了什么?

MSe*_*ert 5

啊,但是异常消息应该给出一个提示:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3.0)
Run Code Online (Sandbox Code Playgroud)

That means optional is the wrong choice here. In fact optional represents None or "that type". But you want an optional argument, not an argument that could be a float and None, e.g.:

>>> fun(10, None)  # doesn't fail because of the signature!
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
Run Code Online (Sandbox Code Playgroud)

I suspect that it just "happens" to work for optional(float) because float is just an "arbitary Python object" from numbas point of view, so with optional(float) you could pass anything in there (this apparently includs not giving the argument). With optional(float64) it could only be None or a float64. That category isn't broad enough to allow not providing the argument.

It works if you give the type Omitted:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
13.0
Run Code Online (Sandbox Code Playgroud)

However it seems like Omitted isn't actually included in the documentation and that it has some "rough edges". For example it can't be compiled in nopython mode with that signature, even though it seems possible without signature:

@njit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3):
    return a + b

TypingError: Failed at nopython (nopython frontend)
Invalid usage of + with parameters (float64, class(float64))

-----------

@njit(['float64(float64, float64)', 'float64(float64, Omitted(3.))'])
def fun(a, b=3):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3)

-----------

@njit
def fun(a, b=3):
    return a + b

>>> fun(10.)
13.0
Run Code Online (Sandbox Code Playgroud)