高效返回数组中第一个满足条件的值的索引

jpp*_*jpp 5 python arrays performance numpy pandas

我需要找到满足条件的 1d NumPy 数组或 Pandas 数字系列中第一个值的索引。数组很大,索引可能靠近数组的开头结尾,或者根本不满足条件。我无法提前判断哪个更有可能。如果不满足条件,则返回值应为-1。我考虑了几种方法。

尝试 1

# func(arr) returns a Boolean array
idx = next(iter(np.where(func(arr))[0]), -1)
Run Code Online (Sandbox Code Playgroud)

但这通常太慢了,因为func(arr)整个数组上应用向量化函数而不是在满足条件时停止。具体来说,当条件在数组开始附近满足时,代价是昂贵的。

尝试 2

np.argmax稍微快一点,但无法识别何时从未满足条件:

np.random.seed(0)
arr = np.random.rand(10**7)

assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)

%timeit next(iter(np.where(arr > 0.999999)[0]), -1)  # 21.2 ms
%timeit np.argmax(arr > 0.999999)                    # 17.7 ms
Run Code Online (Sandbox Code Playgroud)

np.argmax(arr > 1.0)返回0,当条件,即一个实例并不满足。

尝试 3

# func(arr) returns a Boolean scalar
idx = next((idx for idx, val in enumerate(arr) if func(arr)), -1)
Run Code Online (Sandbox Code Playgroud)

但是当条件在数组末尾附近满足时,这太慢了。大概这是因为生成器表达式从大量__next__调用中产生了昂贵的开销。

总是一种妥协还是有一种方法,对于 generic func,可以有效地提取第一个索引?

基准测试

对于基准测试,假设func在值大于给定常量时找到索引:

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

# Start of array benchmark
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

# End of array benchmark
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms
Run Code Online (Sandbox Code Playgroud)

jpp*_*jpp 6

numba

有了numba它可以优化这两个场景。从语法上讲,您只需要构造一个带有简单for循环的函数:

from numba import njit

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

idx = get_first_index_nb(A, 0.9)
Run Code Online (Sandbox Code Playgroud)

Numba 通过 JIT(“及时”)编译代码和利用CPU 级优化来提高性能。一个常规的 for无环路@njit装饰通常会比你已经尝试了在条件满足后期的情况下的方法。

对于 Pandas 数字系列df['data'],您可以简单地将 NumPy 表示提供给 JIT 编译函数:

idx = get_first_index_nb(df['data'].values, 0.9)
Run Code Online (Sandbox Code Playgroud)

概括

由于numba允许将函数作为参数,并且假设传递的函数也可以进行 JIT 编译,因此您可以找到一种方法来计算满足任意 条件的第n个索引func

@njit
def get_nth_index_count(A, func, count):
    c = 0
    for i in range(len(A)):
        if func(A[i]):
            c += 1
            if c == count:
                return i
    return -1

@njit
def func(val):
    return val > 0.9

# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)
Run Code Online (Sandbox Code Playgroud)

对于倒数第三值,您可以提供相反的 ,arr[::-1]并否定来自 的结果len(arr) - 1,这- 1是考虑 0 索引所必需的。

性能基准测试

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

def get_first_index_np(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

%timeit get_first_index_nb(arr, m)                                 # 375 ns
%timeit get_first_index_np(arr, m)                                 # 2.71 µs
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

%timeit get_first_index_nb(arr, n)                                 # 204 µs
%timeit get_first_index_np(arr, n)                                 # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms
Run Code Online (Sandbox Code Playgroud)