Numba - nopython 模式是否支持元组列表?

Lae*_*ven 5 python arrays optimization numpy numba

我想澄清一下,这是我第一次使用 Numba,所以我距离成为专家还很远。我正在尝试手动实现一个简单的 KNN,代码如下:

@jit(nopython=True)
def knn(training_set, test_set):
for q in range(len(test_set)):
    indexes = [-1]
    values = [np.inf]
    thres = values[-1]

    for u in range(len(training_set)):
        dist = 0
        flag = False
        dist = knn_dist(training_set[u], test_set[q], thres)
        if dist == 0:
            flag = True
        if not flag:

            '''
            Binary search to obtain the index
            '''    

            # Various code

return
Run Code Online (Sandbox Code Playgroud)

现在,我想使用numba的nopython模式来优化代码,以下是部分错误:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
 in _call_incref_decref(self, builder, root_type, typ, value, funcname, getters)
    185             try:
--> 186                 meminfo = data_model.get_nrt_meminfo(builder, value)
    187             except NotImplementedError as e:

 in get_nrt_meminfo(self, builder, value)
    328                 raise NotImplementedError(
--> 329                     "unsupported nested memory-managed object")
    330         return value

NotImplementedError: unsupported nested memory-managed object
Run Code Online (Sandbox Code Playgroud)

训练集和测试集都是元组列表的列表,我想知道nopython是否支持这种数据结构,如果不支持(看起来),我可以使用哪种数据结构来实现它?我是否被迫更改 numba 模式?

为了更好地澄清,训练/测试的示例如下:

[[(0, 1), (1, 1), (2, 1), (3, 2), (4, 5)], [(0, 2), (1, 4), (2, 3), (3, 4), (4, 2)], [(0, 5), (1, 4), (2, 3), (3, 4), (4, 2)], [(0, 6), (1, 5), (2, 4), (3, 3), (4, 2)], [(0, 0), (1, 9), (2, 8), (3, 9), (4, 8)], [(0, 5), (1, 4), (2, 3), (3, 4), (4, 2)]]
Run Code Online (Sandbox Code Playgroud)

jpp*_*jpp 5

nopython 模式支持元组列表吗?

是的,它确实。但是,正如您的错误消息所暗示的那样,不是嵌套列表。

我是否被迫更改 numba 模式?

不,你不是。


您可以轻松地将元组列表转换L为常规 NumPy 数组:

L_arr = np.array(L)
Run Code Online (Sandbox Code Playgroud)

这是一个演示以及您可以如何自行测试:

from numba import jit

L = [[(0, 1), (1, 1), (2, 1), (3, 2), (4, 5)], [(0, 2), (1, 4), (2, 3), (3, 4), (4, 2)],
     [(0, 5), (1, 4), (2, 3), (3, 4), (4, 2)], [(0, 6), (1, 5), (2, 4), (3, 3), (4, 2)],
     [(0, 0), (1, 9), (2, 8), (3, 9), (4, 8)], [(0, 5), (1, 4), (2, 3), (3, 4), (4, 2)]]

L_arr = np.array(L)

@jit(nopython=True)
def foo(x):
    return x
Run Code Online (Sandbox Code Playgroud)

这样L会出现错误:

print(foo(L))

LoweringError: Failed at nopython (nopython mode backend)
reflected list(reflected list((int64 x 2))): unsupported nested memory-managed object
Run Code Online (Sandbox Code Playgroud)

使用L_arr,您将获得一个形状为 的 3 维 NumPy 数组(6, 5, 2)

print(foo(L_arr))

array([[[0, 1],
        [1, 1],
        [2, 1],
        [3, 2],
        [4, 5]],
        ...
       [[0, 5],
        [1, 4],
        [2, 3],
        [3, 4],
        [4, 2]]])
Run Code Online (Sandbox Code Playgroud)

然后,您可能希望重构逻辑,以便更有效地使用 NumPy 数组而不是嵌套的元组列表。

  • 问题解决了,非常感谢。如果有人关注此线程,我想添加信息。当需要迭代 numpy 数组时,显然 Numba 的 nopython 仅通过使用索引来工作(例如,不要执行“for elem in array”,而是“for i in range(array.shape[0])” ”。 (2认同)