Is there a faster method than np.isin for large arrays?

SY *_*eon 2 python optimization numpy pandas isin

For large arrays (n > 1e8), is there a faster way than np.isin to check for common elements?

I have tried several approaches, such as pandas isin and Cython, but all of them are slower than np.isin.

Example (test whether each element of a 1D array is also present in a second array):

import time
import numpy as np

num = int(1e8)
a = np.random.rand(num)
b = np.random.rand(num)

ref = time.time()
ainb = np.isin(a, b)  # True where a[i] also occurs somewhere in b
print(a[ainb])
print(time.time() - ref, 'sec')

>>> [0.23591019 0.46102523]
>>> 65.45570135116577 sec
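To make the semantics of the benchmark concrete, here is a small sketch with toy arrays (chosen for illustration, not taken from the question): np.isin returns a boolean mask with the same shape as its first argument, and indexing a with that mask keeps the matching elements in their original order, duplicates included.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 2])
b = np.array([4, 2, 9])

# Boolean mask, same shape as `a`: True where a[i] occurs anywhere in `b`
ainb = np.isin(a, b)
print(ainb)     # [False  True False  True  True]

# Boolean indexing keeps order and duplicates of `a`
print(a[ainb])  # [2 4 2]
```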

nor*_*ok2 6

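If you need a replacement for np.isin() that is drop-in for your use case but potentially faster, you can use a Python set() for the membership checks and accelerate the explicit loop with Numba: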

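import numpy as np
import numba as nb


@nb.njit
def is_in_set_nb(a, b):
    shape = a.shape
    a = a.ravel()              # flatten so N-D input is handled by a single loop
    n = len(a)
    result = np.full(n, False)
    set_b = set(b)             # hash set: O(1) average-time membership test
    for i in range(n):
        if a[i] in set_b:
            result[i] = True
    return result.reshape(shape)  # restore the original shape of `a`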

Note that there is some (cheap) extra code to make it work with N-dimensional arrays; it can be omitted if you only need 1D.
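To see what the shape-handling lines do, here is a plain Python/NumPy sketch of the same loop without Numba (the name is_in_set_py is hypothetical, not from the answer). It is slow, but it can serve as a reference for checking the Numba versions against np.isin() on a small 2D input.

```python
import numpy as np


def is_in_set_py(a, b):
    """Hypothetical pure-Python reference for is_in_set_nb (no Numba)."""
    shape = a.shape                   # remember the original N-D shape
    flat = a.ravel()                  # 1D view (or copy) of `a`
    set_b = set(b.ravel().tolist())   # hash set of `b` as Python ints
    result = np.zeros(flat.size, dtype=bool)
    for i, x in enumerate(flat.tolist()):
        if x in set_b:
            result[i] = True
    return result.reshape(shape)      # give the mask the shape of `a`


rng = np.random.default_rng(0)
a = rng.integers(0, 20, size=(4, 5))  # 2D input
b = rng.integers(0, 20, size=10)
assert np.array_equal(is_in_set_py(a, b), np.isin(a, b))
```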

The Numba version can be made even faster by adding parallelization:

import numpy as np
import numba as nb


@nb.njit(parallel=True)
def is_in_set_pnb(a, b):
    shape = a.shape
    a = a.ravel()
    n = len(a)
    result = np.full(n, False)
    set_b = set(b)
    # prange splits the iterations across threads; each one writes only result[i]
    for i in nb.prange(n):
        if a[i] in set_b:
            result[i] = True
    return result.reshape(shape)
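The parallel loop is safe without locks because iterations never touch the same output element and set_b is only read. The sketch below imitates that disjoint-write pattern with a standard-library thread pool that splits the index range into chunks (is_in_set_chunked and n_chunks are hypothetical names). It illustrates the idea only: in CPython the GIL serializes the Python-level set lookups, so unlike Numba's prange it should not be expected to beat a plain loop.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def is_in_set_chunked(a, b, n_chunks=4):
    """Illustration only: each worker fills its own slice of `result`."""
    flat = a.ravel()
    set_b = set(b.ravel().tolist())   # shared, read-only for all workers
    result = np.zeros(flat.size, dtype=bool)
    # chunk boundaries 0 = bounds[0] <= ... <= bounds[-1] = flat.size
    bounds = np.linspace(0, flat.size, n_chunks + 1).astype(int)

    def work(start, stop):
        # disjoint index ranges: no two workers write the same element
        for i in range(start, stop):
            result[i] = flat[i].item() in set_b

    with ThreadPoolExecutor(max_workers=n_chunks) as pool:
        list(pool.map(work, bounds[:-1], bounds[1:]))  # list() re-raises worker errors
    return result.reshape(a.shape)
```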

The Numba versions are much faster than np.isin(), than set() intersection, and than the solution without Numba acceleration, is_in_set():

import numpy as np


def is_in_set(a, b):
    # Same set-based idea, but as an interpreted Python loop (1D `a` only)
    set_b = set(b)
    return np.array([x in set_b for x in a])
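A possible variant of is_in_set (the name is_in_set_fromiter is hypothetical and it was not benchmarked in this answer) fills a preallocated boolean array with np.fromiter instead of first building a Python list of booleans. It is still an interpreted loop, so at most a modest memory saving should be assumed, not a Numba-like speedup.

```python
import numpy as np


def is_in_set_fromiter(a, b):
    """Hypothetical variant of is_in_set using np.fromiter (1D `a` only)."""
    set_b = set(b)
    # count=len(a) lets NumPy allocate the bool output once, up front
    return np.fromiter((x in set_b for x in a), dtype=bool, count=len(a))


a = np.array([3, 1, 4, 1, 5])
b = np.array([1, 5, 9])
print(is_in_set_fromiter(a, b))  # [False  True False  True  True]
```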

With an input size of ten million elements:

n = 10 ** 7
k = n // 3
np.random.seed(0)
# note: I used `int`s because I wanted to be able to control the collisions
a = np.random.randint(0, k * n, n)
b = np.random.randint(0, k * n, n)


%timeit ainb = np.isin(a, b); a[ainb]
# 1 loop, best of 3: 3.94 s per loop
%timeit ainb = is_in_set_nb(a, b); a[ainb]
# 1 loop, best of 3: 814 ms per loop
%timeit ainb = is_in_set_pnb(a, b); a[ainb]
# 1 loop, best of 3: 740 ms per loop
%timeit ainb = is_in_set(a, b); a[ainb]
# 1 loop, best of 3: 7.69 s per loop
%timeit set(a).intersection(b)  # not a drop-in replacement
# 1 loop, best of 3: 6.79 s per loop
%timeit set(a) & set(b)  # not a drop-in replacement
# 1 loop, best of 3: 8.98 s per loop

And with one hundred million elements (the last two methods ended up filling all the memory, so they are omitted):

n = 10 ** 8
k = n // 3
np.random.seed(0)
a = np.random.randint(0, k * n, n)
b = np.random.randint(0, k * n, n)


%timeit ainb = np.isin(a, b); a[ainb]
# 1 loop, best of 3: 1min 4s per loop
%timeit ainb = is_in_set_nb(a, b); a[ainb]
# 1 loop, best of 3: 13.1 s per loop
%timeit ainb = is_in_set_pnb(a, b); a[ainb]
# 1 loop, best of 3: 11.4 s per loop
%timeit ainb = is_in_set(a, b); a[ainb]
# 1 loop, best of 3: 2min 5s per loop
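The "not a drop-in replacement" remarks in the ten-million-element benchmark come from the fact that set intersection returns different data than the mask: a[np.isin(a, b)] keeps the order and the duplicates of a, while set(a).intersection(b) yields an unordered collection of unique values. A toy example (arrays chosen for illustration):

```python
import numpy as np

a = np.array([5, 1, 5, 3, 7])
b = np.array([7, 5, 8])

mask_result = a[np.isin(a, b)]       # order and duplicates of `a` preserved
set_result = set(a).intersection(b)  # unordered, duplicates collapsed

print(mask_result)                          # [5 5 7]
print(sorted(int(x) for x in set_result))   # [5, 7]
```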

Adding more timings for smaller inputs, covering all combinations of the lengths of a and b:

funcs = np.isin, is_in_set_nb, is_in_set_pnb
sep = '    '
print(f'({"n=len(a)":>9s},{"m=len(b)":>9s})', end=sep)
for func in funcs:
    print(f'{func.__name__:15s}', end=sep)
print()
I, J = 7, 7
for i in range(I):
    for j in range(J):
        n = 10 ** i
        m = 10 ** j
        a = np.random.randint(0, m * n, n)
        b = np.random.randint(0, m * n, m)
        print(f'({n:9d},{m:9d})', end=sep)
        for func in funcs:
            result = %timeit -q -o func(a, b)
            print(f'{result.best * 1e3:12.3f} ms', end=sep)
        print()

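This shows that Numba and parallelization are very beneficial for larger inputs, while being slightly less efficient for smaller ones. Still, they outperform np.isin() in most of the tests above.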
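The %timeit calls, including the `result = %timeit -q -o ...` form, only work inside IPython/Jupyter. A minimal sketch of the same (len(a), len(b)) grid for a plain Python script, using the standard-library timeit module, is below; the helper time_grid and its parameters are hypothetical, and it draws inputs with np.random.default_rng instead of np.random.randint. Passing is_in_set_nb and is_in_set_pnb in funcs should reproduce the comparison; taking the minimum over repeats (repeat >= 2) keeps the one-off Numba compilation of the first call out of the reported figure.

```python
import timeit

import numpy as np


def time_grid(funcs, exps=range(4), repeat=3, number=10):
    """Hypothetical plain-script analogue of the %timeit grid.

    Returns {(n, m, name): ms}: best-of-`repeat` average milliseconds per call
    for n = len(a) = 10**i and m = len(b) = 10**j.
    """
    rng = np.random.default_rng(0)
    results = {}
    for i in exps:
        for j in exps:
            n, m = 10 ** i, 10 ** j
            a = rng.integers(0, m * n, n)
            b = rng.integers(0, m * n, m)
            for func in funcs:
                # `number` calls per measurement, measured `repeat` times
                times = timeit.repeat(lambda: func(a, b), repeat=repeat, number=number)
                name = getattr(func, '__name__', repr(func))
                results[(n, m, name)] = min(times) / number * 1e3
    return results


for (n, m, name), ms in time_grid([np.isin], exps=range(3)).items():
    print(f'({n:9d},{m:9d})    {name:15s}{ms:12.3f} ms')
```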