在 numba.jit(nopython=True)
函数内部,我正在计算数千个 numpy 数组(一维,整数数据类型)并将它们附加到列表中。问题是某些数组看起来相等,但我不需要重复项。所以我需要一种有效的方法来检查新数组是否已存在于列表中。
在Python中可以这样完成:
import numpy as np
import numba as nb
# @nb.jit(nopython=True)
def foo(n):
uniques = []
uniques_set = set()
for _ in range(n):
arr = np.random.randint(0, 2, 2)
arr_hashable = make_hashable(arr)
if not arr_hashable in uniques_set:
uniques_set.add(arr_hashable)
uniques.append(arr)
return uniques
Run Code Online (Sandbox Code Playgroud)
我尝试了两种方法来解决这个问题:
将数组转换为元组并将元组放入集合中。
def make_hashable(arr):
return tuple(arr)
Run Code Online (Sandbox Code Playgroud)
但不幸的是,直接元组构造在 nopython 模式下不能以这种方式工作。我也尝试过这种方式:
def make_hashable(arr):
res = ()
for n in arr:
res += (n,)
return res
Run Code Online (Sandbox Code Playgroud)
和我能想到的其他类似的解决方法,但它们都在 nopython 模式下失败并出现 TypeError。
将数组转换为字符串并将其放入集合中。
def make_hashable(arr):
return arr.tostring()
Run Code Online (Sandbox Code Playgroud)
还尝试了所有可能的方法将数组转换为字符串,但似乎 numba 目前不支持字符串转换
也许有不同的方法来检查(有效)数组是否已存在于列表中?我的 numba 版本是 0.44。多谢。
我有 numba 0.58,但我知道解决您的问题的唯一方法仍然是使用回调到对象模式来散列数组。像这样:
import numpy as np
import numba as nb
def make_hashable(arr):
return hash(arr.tobytes())
@nb.jit(nopython=True)
def foo(n):
uniques = []
uniques_set = set()
for _ in range(n):
arr = np.random.randint(0, 2, 2)
with nb.objmode(arr_hashable='intp'):
arr_hashable = make_hashable(arr)
if arr_hashable not in uniques_set:
uniques_set.add(arr_hashable)
uniques.append(arr)
return uniques
foo(100)
# => [array([0, 0]), array([0, 1]), array([1, 1]), array([1, 0])]
Run Code Online (Sandbox Code Playgroud)
编辑:
如果您想正确处理冲突,可以使用字典而不是将哈希值映射到先前看到的数组列表的集合。当然,代码有点复杂:
import numpy as np
import numba as nb
def make_hashable(arr):
return hash(arr.tobytes())
list_type = nb.types.ListType(int64[:])
@nb.jit(nopython=True)
def foo(n):
uniques = []
uniques_dict = nb.typed.Dict.empty(nb.int64, list_type)
for _ in range(n):
arr = np.random.randint(0, 2, 2)
with nb.objmode(arr_hashable='intp'):
arr_hashable = make_hashable(arr)
is_seen = False
seen_arrs = uniques_dict.get(arr_hashable)
if seen_arrs is not None:
for seen_arr in seen_arrs:
if np.array_equal(arr, seen_arr):
is_seen = True
break
if not is_seen:
if seen_arrs is None:
seen_arrs = nb.typed.List.empty_list(int64[:])
uniques_dict[arr_hashable] = seen_arrs
seen_arrs.append(arr)
uniques.append(arr)
return uniques
foo(100)
# => [array([1, 0]), array([0, 0]), array([0, 1]), array([1, 1])]
Run Code Online (Sandbox Code Playgroud)
我必须list_type
在外面定义,否则编译失败。例如,您可以使用总是返回 1 的错误哈希函数来测试代码。
希望未来 numba 能够支持,bytes
这一切都变得不必要了。已经有一张票:https://github.com/numba/numba/issues/5149
归档时间: |
|
查看次数: |
565 次 |
最近记录: |