Numba 字典:JIT() 装饰器中的签名

Yev*_*kin 5 python jit dictionary numpy numba

我的函数将一个 numpy 数组列表和一个字典(或一个字典列表)作为输入参数,并返回一个值列表。numpy 数组的列表很长,并且数组可能具有不同的形状。虽然我可以单独传递 numpy 数组,但出于管理目的,我真的很想形成一个 numpy 数组元组并将它们这样传递到我的函数中。没有字典(根据 numba >=0.43 专门形成)整个设置工作正常 - 请参阅下面的脚本。因为输入和输出的结构是元组形式,所以JIT需要签名——没有它它无法确定数据结构的类型。但是,无论我如何尝试将字典“d”声明为 JIT 装饰器,我都无法使脚本正常工作。请帮助提供想法或解决方案(如果存在)。

非常感谢

'''Python:

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

@njit(  'Tuple( (f8,f8) )(Tuple( (f8[:],f8[:]) ))'  )

def somefunction(lst_arr):
    arr1, arr2 = lst_arr

    summ = 0
    prod = 1
    for i in arr2:
        summ += i
    for j in arr1:
        prod *= j

    result = (summ,prod)
    return result

a = np.arange(5)+1.0
b = np.arange(5)+11.0
arg = (a,b)
print(a,b)

print(somefunction(arg))


# ~~ The Dict.empty() constructs a typed dictionary.
d = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,)

d['k1'] = 1.5
d['k2'] = 0.5
Run Code Online (Sandbox Code Playgroud)

'''

我希望将 'd'-dictionary 传递给 'somefunction' 并在内部使用它与 dict 键...形式示例如下: result = (summ * d['k1'], prod * d['k2'])

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

@njit(  'Tuple( (f8,f8) )(Tuple( (f8[:],f8[:]) ), Dict)'  )

def somefunction(lst_arr, mydict):
    arr1, arr2 = lst_arr

    summ = 0
    prod = 1
    for i in arr2:
        summ += i
    for j in arr1:
        prod *= j

    result = (summ*mydict['k1'],prod*mydict['k2'])
    return result

# ~~ Input numpy arrays
a = np.arange(5)+1.0
b = np.arange(5)+11.0
arg = (a,b)

# ~~ Input dictionary for the function 
d = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64)

d['k1'] = 1.5
d['k2'] = 0.5


# ~~ Run function and print results
print(somefunction(arg, d))
Run Code Online (Sandbox Code Playgroud)

Moh*_*hif 1

我正在使用的版本0.45.1。您可以简单地传递字典,而无需在字典中声明类型:

d = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64[:],
)
d['k1'] = np.arange(5) + 1.0
d['k2'] = np.arange(5) + 11.0

# Numba will infer the type on it's own.
@njit
def somefunction2(d):
    prod = 1

    # I am assuming you want sum of second array and product of second
    result = (d['k2'].sum(), d['k1'].prod())

    return result

print(somefunction(d))
# Output : (65.0, 120.0)
Run Code Online (Sandbox Code Playgroud)

作为参考,您可以从官方文档中查看此示例。

更新
在您的情况下,您可以简单地jit推断它自己的类型,它应该可以工作,以下代码对我有用:

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
from numba.types import DictType

# Let jit infer the types on it's own
@njit
def somefunction(lst_arr, mydict):
    arr1, arr2 = lst_arr
    summ = 0
    prod = 1
    for i in arr2:
        summ += i
    for j in arr1:
        prod *= j

    result = (summ*mydict['k1'],prod*mydict['k2'])
    return result

# ~~ Input numpy arrays
a = np.arange(5)+1.0
b = np.arange(10)+11.0  #<--------------- This is of different shape 
arg = (a,b)

# ~~ Input dictionary for the function 
d = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64)

d['k1'] = 1.5
d['k2'] = 0.5


# This works now
print(somefunction(arg, d))
Run Code Online (Sandbox Code Playgroud)

您可以在这里查看官方文档:

除非必要,建议让 Numba 使用 @jit 的无签名变体来推断参数类型。

我尝试了各种方法,但这是唯一适用于您指定的问题的方法。