我应该如何使用数组(或元组)作为 Numba 类型字典的键和值?

use*_*897 4 numba

我有以下代码尝试将键值对存储到 numba 字典中。Numba 的官方页面说新的类型字典支持数组作为键,但我无法让它工作。错误消息表明密钥不能是哈希值。知道如何让它发挥作用吗?

In [7]: from numba.typed import Dict 
   ...: from numba import types 
   ...: import numpy as np        

In [15]: dd = Dict.empty(key_type=types.int32[::1], value_type=types.int32[::1],)                                                                                                                                  

In [16]: key = np.asarray([1,2,3], dtype=np.int32)                                                                                                                                                                 

In [17]: dd[key] = key   
Run Code Online (Sandbox Code Playgroud)

错误信息:

TypingError:在 nopython 模式管道中失败(步骤:nopython 前端)数组类型的未知属性“哈希”(int32,1d,C)

编辑:我可能错过了一些东西。我可以在解释器中使用 types.UniTuple (没有 @jit 装饰器)。但是,当我将以下函数放入脚本 a.py 并使用命令“python a.py”运行它时,我收到 UniTuple 未找到错误。

@jit(nopython=True)
def go_fast2(date, starttime, id, tt, result): # Function is compiled and runs in machine code
    prev_record = Dict.empty(key_type=types.UniTuple(types.int64, 2),  value_type=types.UniTuple(types.int64, 3),)
    for i in range(1, length):
        key = np.asarray([date[i], id[i]], dtype=np.int64)
        thistt = tt[i]
        thistime = starttime[i]
        if key in prev_record:
            prev_time = prev_record[key][0]
            prev_tt = prev_record[key][1]
            prev_res = prev_record[key][2]
            if thistt == prev_tt and thistime - prev_time <= 30 * 1000 * 1000: # with in a 10 seconds window
                result[i] = prev_res + 1
            else:
                result[i] = 0
            prev_record[key] = np.asarray((thistime, thistt, result[i]), dtype=np.int64)
        else:
            result[i] = 0
            prev_record[key] = np.asarray((thistime, thistt, result[i]), dtype=np.int64)
    return 
Run Code Online (Sandbox Code Playgroud)

Jos*_*del 7

当前的文档说:

可接受的键/值类型包括但不限于:unicode 字符串、数组、标量、元组。

从措辞来看,您似乎可以使用数组作为键类型,但这是不正确的,因为数组是可变的,因此不可散列。它也不适用于标准的 python 字典。您可以将数组转换为元组,这样就可以了:

dd = Dict.empty(
    key_type=types.UniTuple(types.int64, 3), 
    value_type=types.int64[::1],)
key = np.asarray([1,2,3], dtype=np.int64)
dd[tuple(key)] = key
Run Code Online (Sandbox Code Playgroud)

请注意,int32您之前使用的 dtype 无法在 64 位计算机上运行,​​因为在调用tuple()数组时,int32 的元组将自动转换为 int64。

另一个问题是元组具有固定大小,因此您不能使用任意大小的数组作为键。