使用 numba 更快的 numpy 替代字符串

Pau*_*Pau 5 python string performance numpy numba

np.isin我正在尝试实现更快的in版本numba,这是我到目前为止所拥有的:

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def isin(a, b):
    out=np.empty(a.shape[0], dtype=nb.boolean)
    b = set(b)
    for i in nb.prange(a.shape[0]):
        if a[i] in b:
            out[i]=True
        else:
            out[i]=False
    return out
Run Code Online (Sandbox Code Playgroud)

对于数字来说它是有效的,如下例所示:

a = np.array([1,2,3,4])
b = np.array([2,4])

isin(a,b)
>>> array([False,  True, False,  True])
Run Code Online (Sandbox Code Playgroud)

而且它比以下更快np.isin

a = np.random.rand(20000)
b = np.random.rand(5000)

%time isin(a,b)
CPU times: user 3.96 ms, sys: 0 ns, total: 3.96 ms
Wall time: 1.05 ms

%time np.isin(a,b)
CPU times: user 11 ms, sys: 0 ns, total: 11 ms
Wall time: 8.48 ms
Run Code Online (Sandbox Code Playgroud)

但是,我想将此函数与包含字符串的数组一起使用。问题是,每当我尝试传递字符串数组时,numba都会抱怨它无法解释set()这些数据的操作。

a = np.array(['A','B','C','D'])
b = np.array(['B','D'])

isin(a,b)
Run Code Online (Sandbox Code Playgroud)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<class 'set'>) found for signature:
 
 >>> set(array([unichr x 1], 1d, C))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload of function 'set': File: numba/core/typing/setdecl.py: Line 20.
    With argument(s): '(array([unichr x 1], 1d, C))':
   No match.

During: resolving callee type: Function(<class 'set'>)
During: typing of call at /tmp/ipykernel_20582/4221597969.py (7)


File "../../../../tmp/ipykernel_20582/4221597969.py", line 7:
<source missing, REPL/exec in use?>
Run Code Online (Sandbox Code Playgroud)

有没有一种方法,比如指定签名,可以让我直接在字符串数组上使用它?

我知道我可以为每个字符串分配一个数值,但对于大型数组,我认为这将花费一段时间,并且会使整个过程比仅使用np.isin.

有任何想法吗?

Jér*_*ard 2

Numba 几乎不支持字符串(虽然bytes支持稍微好一些)。集合和字典的支持有一些严格的限制,并且是相当实验性的/新的。关于文档,尚不支持字符串集

集合必须严格同质:Numba 将拒绝任何包含不同类型对象的集合,即使类型兼容(例如, {1, 2.5} 会被拒绝,因为它包含 aint和 a float)。不支持在集合中使用引用计数类型,例如strings 。

您可以尝试使用二分搜索进行作弊。不幸的是,np.searchsorted尚未实现字符串类型的 Numpy 数组(尽管np.unique是)。我认为你可以自己实现二分搜索,但这最终很麻烦。我不确定这最终会更快,但我认为它应该是因为O(Ns Na log Nb))运行时间复杂性(具有唯一项目Ns的字符串长度的平均大小bNa中的项目数aNb中的唯一项目数b)。事实上, 的运行时间复杂度np.isinO(Ns (Na+Nb) log (Na+Nb))if 数组的大小相似并且O(Ns Na Nb)ifNb远小于Na。请注意,最好的理论运行时间复杂度是 AFAIK,这O(Ns (Na + Nb))要归功于具有良好哈希函数的哈希表(尝试也可以实现这一点,但它们实际上应该更慢,除非哈希函数不是很好)。

请注意,类型字典支持静态声明的字符串,但不支持动态字符串(这是静态字符串的实验性功能)。

另一个作弊(应该有效)是将字符串哈希存储为类型化字典的键,并将每个哈希与引用字符串位置的索引数组相关联,以b获取关联键的哈希。并行循环需要对a项目进行散列并获取具有此散列的字符串项目的位置b,以便您可以比较字符串。更快的实现是假设b字符串的哈希函数是完美的并且不存在冲突,因此您可以直接使用TypedDict[np.int64, np.int64]哈希表。您可以在构建时在运行时测试这个假设b。写这样的代码有点乏味。请注意,此实现最终可能不会比 Numpy 更快,因为 NumbaTypedDict目前相当慢......但是,在具有足够内核的处理器上并行执行应该更快。