Python/NumPy首次出现子数组

Ven*_*tta 22 python arrays numpy

在Python或NumPy中,找出第一次出现的子阵列的最佳方法是什么?

例如,我有

a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
Run Code Online (Sandbox Code Playgroud)

找出b出现在哪里的最快方法(运行时间)是什么?我理解字符串这非常容易,但对于列表或numpy ndarray呢?

非常感谢!

[编辑]我更喜欢numpy解决方案,因为从我的经验来看,numpy矢量化比Python列表理解要快得多.同时,大数组是巨大的,所以我不想把它转换成字符串; 这将是(太长).

dan*_*nem 18

我的第一个答案,但我认为这应该工作....

[x for x in xrange(len(a)) if a[x:x+len(b)] == b]
Run Code Online (Sandbox Code Playgroud)

返回模式开始的索引.

  • 这可能不是最快的解决方案,但+1是最简单的答案.这可能适合许多用户的需求,特别是如果numpy不可用. (2认同)
  • 在Python 3中使用`range`而不是`xrange`. (2认同)
  • 为了提高性能,您可以将“len(a)”替换为“len(a) - len(b) + 1” (2认同)

sen*_*rle 17

我假设你正在寻找一个特定于numpy的解决方案,而不是简单的列表理解或循环.一种方法可能是使用滚动窗口技术来搜索适当大小的窗口.这是rolling_window函数:

>>> def rolling_window(a, size):
...     shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
...     strides = a.strides + (a. strides[-1],)
...     return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
... 
Run Code Online (Sandbox Code Playgroud)

然后你可以做类似的事情

>>> a = numpy.arange(10)
>>> numpy.random.shuffle(a)
>>> a
array([7, 3, 6, 8, 4, 0, 9, 2, 1, 5])
>>> rolling_window(a, 3) == [8, 4, 0]
array([[False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)
Run Code Online (Sandbox Code Playgroud)

为了使它真正有用,你必须使用all以下方法沿轴1减少它:

>>> numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
array([False, False, False,  True, False, False, False, False], dtype=bool)
Run Code Online (Sandbox Code Playgroud)

然后你可以使用它然而你使用布尔数组.获取索引的简单方法:

>>> bool_indices = numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
>>> numpy.mgrid[0:len(bool_indices)][bool_indices]
array([3])
Run Code Online (Sandbox Code Playgroud)

对于列表,您可以调整其中一个滚动窗口迭代器以使用类似的方法.

对于非常大的数组和子数组,您可以像这样保存内存:

>>> windows = rolling_window(a, 3)
>>> sub = [8, 4, 0]
>>> hits = numpy.ones((len(a) - len(sub) + 1,), dtype=bool)
>>> for i, x in enumerate(sub):
...     hits &= numpy.in1d(windows[:,i], [x])
... 
>>> hits
array([False, False, False,  True, False, False, False, False], dtype=bool)
>>> hits.nonzero()
(array([3]),)
Run Code Online (Sandbox Code Playgroud)

另一方面,这可能会更慢.没有测试,不清楚多慢?看看Jamie对另一个记忆保存选项的回答,该选项必须检查误报.我想这两种解决方案之间的速度差异在很大程度上取决于输入的性质.


Jai*_*ime 17

基于卷积的方法,应该比stride_tricks基于方法的内存效率更高:

def find_subsequence(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq,
                                       subseq, mode='valid') == target)[0]
    # some of the candidates entries may be false positives, double check
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    return candidates[mask]
Run Code Online (Sandbox Code Playgroud)

对于非常大的数组,可能无法使用stride_tricks方法,但这个仍然有效:

haystack = np.random.randint(1000, size=(1e6))
needle = np.random.randint(1000, size=(100,))
# Hide 10 needles in the haystack
place = np.random.randint(1e6 - 100 + 1, size=10)
for idx in place:
    haystack[idx:idx+100] = needle

In [3]: find_subsequence(haystack, needle)
Out[3]: 
array([253824, 321497, 414169, 456777, 635055, 879149, 884282, 954848,
       961100, 973481], dtype=int64)

In [4]: np.all(np.sort(place) == find_subsequence(haystack, needle))
Out[4]: True

In [5]: %timeit find_subsequence(haystack, needle)
10 loops, best of 3: 79.2 ms per loop
Run Code Online (Sandbox Code Playgroud)


nor*_*ok2 14

编辑以包括更深入的讨论、更好的代码和更多的基准测试)


概括

对于原始速度和效率,可以使用一种经典算法的 Cython 或 Numba 加速版本(当输入分别是 Python 序列或 NumPy 数组时)。

推荐的方法是:

  • find_kmp_cy()用于Python序列(listtuple等)
  • find_kmp_nb() 对于 NumPy 数组

其他有效的方法是find_rk_cy()find_rk_nb()哪些是更有效的内存,但不能保证在线性时间内运行。

如果用Cython / Numba不到位,再次既find_kmp()find_rk()对于大多数使用情况良好的全方位解决方案,虽然在平均情况下,并为Python序列,幼稚的做法,以某种形式,特别是find_pivot(),可能会更快。对于 NumPy 数组,find_conv()(来自@Jaime 的回答)优于任何非加速的朴素方法。

(完整代码如下,这里那里。)


理论

这是计算机科学中的一个经典问题,被称为字符串搜索或字符串匹配问题。基于两个嵌套循环的朴素方法的O(n + m)平均计算复杂度为,但最坏的情况是O(n m)。多年来,已经开发了许多替代方法来保证更好的最坏情况性能。

在经典算法中,最适合通用序列的算法(因为它们不依赖于字母表)是:

最后一种算法的效率依赖于滚动散列的计算,因此可能需要一些额外的输入知识以获得最佳性能。最终,它最适合同类数据,例如数字数组。Python 中数值数组的一个显着例子当然是 N​​umPy 数组。

评论

  • 朴素的算法如此简单,使其适用于 Python 中具有不同运行时速度的不同实现。
  • 其他算法在可以通过语言技巧优化的方面不太灵活。
  • Python 中的显式循环可能是速度瓶颈,可以使用多种技巧在解释器之外执行循环。
  • Cython特别擅长加速通用 Python 代码的显式循环。
  • Numba特别擅长加速 NumPy 数组的显式循环。
  • 这是生成器的绝佳用例,因此所有代码都将使用它们而不是常规函数。

蟒序列(listtuple等)

基于朴素算法

  • find_loop()find_loop_cy()find_loop_nb()这是明确的,仅环纯Python,用Cython与Numba JIT编译分别实施。请注意forceobj=TrueNumba 版本中的 ,这是必需的,因为我们使用 Python 对象输入。
def find_loop(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
Run Code Online (Sandbox Code Playgroud)
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_loop_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
Run Code Online (Sandbox Code Playgroud)
find_loop_nb = nb.jit(find_loop, forceobj=True)
find_loop_nb.__name__ = 'find_loop_nb'
Run Code Online (Sandbox Code Playgroud)
  • find_all()用推导式all()生成器替换内循环
def find_all(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if all(seq[i + j] == subseq[j] for j in range(m)):
            yield i
Run Code Online (Sandbox Code Playgroud)
  • find_slice()==切片后用直接比较替换内循环[]
def find_slice(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i:i + m] == subseq:
            yield i
Run Code Online (Sandbox Code Playgroud)
  • find_mix()并在切片后find_mix2()用直接比较替换内部循环,但在第一个(和最后一个)字符上包含一两个额外的短路,这可能更快,因为使用 an切片比使用 a 切片快得多。==[]intslice()
def find_mix(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i:i + m] == subseq:
            yield i
Run Code Online (Sandbox Code Playgroud)
def find_mix2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i + m - 1] == subseq[m - 1] \
                and seq[i:i + m] == subseq:
            yield i
Run Code Online (Sandbox Code Playgroud)
  • find_pivot()并使用子序列的第一项find_pivot2()将外循环替换为多次.index()调用,同时对内循环使用切片,最终在最后一项(构造的第一个匹配项)上进行额外的短路。多个.index()调用被包装在一个index_all()生成器中(它本身可能很有用)。
def index_all(seq, item, start=0, stop=-1):
    try:
        n = len(seq)
        if n > 0:
            start %= n
            stop %= n
            i = start
            while True:
                i = seq.index(item, i)
                if i <= stop:
                    yield i
                    i += 1
                else:
                    return
        else:
            return
    except ValueError:
        pass


def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i:i + m] == subseq:
            yield i
Run Code Online (Sandbox Code Playgroud)
def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i + m - 1] == subseq[m - 1] and seq[i:i + m] == subseq:
            yield i
Run Code Online (Sandbox Code Playgroud)

基于 Knuth-Morris-Pratt (KMP) 算法

  • find_kmp()是该算法的简单 Python 实现。由于没有简单的循环或可以使用 a 进行切片的地方slice(),因此除了使用 Cython(Numba 将再次需要forceobj=True这会导致代码变慢)之外,没有太多的优化工作要做。
def find_kmp(seq, subseq):
    n = len(seq)
    m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    j = 1
    k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    i = j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1
Run Code Online (Sandbox Code Playgroud)
  • find_kmp_cy() 是算法的 Cython 实现,其中索引使用 C int 数据类型,这导致代码速度更快。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_kmp_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    cdef Py_ssize_t j = 1
    cdef Py_ssize_t k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    cdef Py_ssize_t i = 0
    j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1
Run Code Online (Sandbox Code Playgroud)

基于 Rabin-Karp (RK) 算法

  • find_rk()是一个纯 Python 实现,它依赖于 Python 的hash()哈希计算(和比较)。这种散列是通过一个简单的sum(). 然后通过减去hash()刚刚访问过的 itemseq[i - 1]的结果hash()并将新考虑的 item的结果相加,从先前的哈希计算翻转seq[i + m - 1]
def find_rk(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])   # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i
Run Code Online (Sandbox Code Playgroud)
  • find_rk_cy()是算法的 Cython 实现,其中索引使用适当的 C 数据类型,这会产生更快的代码。请注意,hash()截断了“基于主机位宽的返回值”。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_rk_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    cdef Py_ssize_t hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    cdef Py_ssize_t curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    cdef Py_ssize_t old_item, new_item
    for i in range(1, n - m + 1):
        old_item = hash(seq[i - 1])
        new_item = hash(seq[i + m - 1])
        curr_hash += new_item - old_item  # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i
Run Code Online (Sandbox Code Playgroud)

基准

上述函数在两个输入上进行评估:

  • 随机输入
def gen_input(n, k=2):
    return tuple(random.randint(0, k - 1) for _ in range(n))
Run Code Online (Sandbox Code Playgroud)
  • (几乎)天真的算法的最差输入
def gen_input_worst(n, k=-2):
    result = [0] * n
    result[k] = 1
    return tuple(result)
Run Code Online (Sandbox Code Playgroud)

subseq具有固定大小(32)。由于有如此多的替代方案,已进行了两个单独的分组,并省略了一些变化非常小且时序几乎相同的解决方案(即find_mix2()find_pivot2())。对于每组,测试两个输入。对于每个基准,都提供了完整图和最快方法的放大图。

天真随机

bm_full_naive_random bm_zoom_naive_random

天真无邪

bm_full_naive_worst bm_zoom_naive_worst

其他随机

bm_full_other_random bm_zoom_other_random

其他最差

bm_full_other_worst bm_zoom_other_worst

(完整代码可在此处获得。)


NumPy 数组

基于朴素算法

  • find_loop()find_loop_cy()find_loop_nb()这是明确的,仅环纯Python,用Cython与Numba JIT编译分别实施。前两个的代码与上面相同,因此省略。find_loop_nb()现在享受快速 JIT 编译。内部循环已在单独的函数中编写,因为它可以被重用find_rk_nb()(在 Numba 函数中调用 Numba 函数不会导致 Python 典型的函数调用惩罚)。
@nb.jit
def _is_equal_nb(seq, subseq, m, i):
    for j in range(m):
        if seq[i + j] != subseq[j]:
            return False
    return True


@nb.jit
def find_loop_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if _is_equal_nb(seq, subseq, m, i):
            yield i
Run Code Online (Sandbox Code Playgroud)
  • find_all()是与上述相同,同时find_slice()find_mix()find_mix2()几乎相同于上述,唯一不同的是,seq[i:i + m] == subseq现在的自变量np.all()np.all(seq[i:i + m] == subseq)

  • find_pivot()find_pivot2()分享与上面相同的想法,除了现在使用np.where()代替index_all()和需要在np.all()调用中包含数组相等性。

def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif np.all(seq[i:i + m] == subseq):
            yield i


def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif seq[i + m - 1] == subseq[m - 1] \
                and np.all(seq[i:i + m] == subseq):
            yield i
Run Code Online (Sandbox Code Playgroud)
  • find_rolling()通过滚动窗口表达循环,并使用np.all(). 这以创建大型临时对象为代价来矢量化所有循环,同时仍然大量应用朴素算法。(该方法来自@senderle answer)。
def rolling_window(arr, size):
    shape = arr.shape[:-1] + (arr.shape[-1] - size + 1, size)
    strides = arr.strides + (arr.strides[-1],)
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)


def find_rolling(seq, subseq):
    bool_indices = np.all(rolling_window(seq, len(subseq)) == subseq, axis=1)
    yield from np.mgrid[0:len(bool_indices)][bool_indices]
Run Code Online (Sandbox Code Playgroud)
  • find_rolling2()是上述内存效率稍高的变体,其中矢量化只是部分的,并且保留了一个显式循环(沿着预期的最短维度 - 的长度subseq)。(该方法也来自@senderle answer)。
def find_rolling2(seq, subseq):
    windows = rolling_window(seq, len(subseq))
    hits = np.ones((len(seq) - len(subseq) + 1,), dtype=bool)
    for i, x in enumerate(subseq):
        hits &= np.in1d(windows[:, i], [x])
    yield from hits.nonzero()[0]
Run Code Online (Sandbox Code Playgroud)

基于 Knuth-Morris-Pratt (KMP) 算法

  • find_kmp()与上面相同,而它find_kmp_nb()是一个直接的 JIT 编译。
find_kmp_nb = nb.jit(find_kmp)
find_kmp_nb.__name__ = 'find_kmp_nb'
Run Code Online (Sandbox Code Playgroud)

基于 Rabin-Karp (RK) 算法

  • find_rk()与上面相同,除了再次seq[i:i + m] == subseq包含在np.all()调用中。

  • find_rk_nb()是上面的Numba加速版本。使用_is_equal_nb()之前定义的来明确确定匹配,而对于散列,它使用 Numba 加速sum_hash_nb()函数,其定义非常简单。

@nb.jit
def sum_hash_nb(arr):
    result = 0
    for x in arr:
        result += hash(x)
    return result


@nb.jit
def find_rk_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if _is_equal_nb(seq, subseq, m, 0):
        yield 0
    hash_subseq = sum_hash_nb(subseq)  # compute hash
    curr_hash = sum_hash_nb(seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])  # update hash
        if hash_subseq == curr_hash and _is_equal_nb(seq, subseq, m, i):
            yield i
Run Code Online (Sandbox Code Playgroud)
  • find_conv()使用伪 Rabin-Karp 方法,其中初始候选者使用np.dot()乘积散列并位于seq和之间的卷积subseqnp.where()。该方法是伪的,因为虽然它仍然使用散列来识别可能的候选对象,但它可能不被视为滚动散列(这取决于 的实际实现np.correlate()。此外,它需要创建一个与输入大小相同的临时数组。 (该方法来自@Jaime answer)。
def find_conv(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq, subseq, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    yield from candidates[mask]
Run Code Online (Sandbox Code Playgroud)

基准

和以前一样,上述函数在两个输入上进行评估:

  • 随机输入
def gen_input(n, k=2):
    return np.random.randint(0, k, n)
Run Code Online (Sandbox Code Playgroud)
  • (几乎)天真的算法的最差输入
def gen_input_worst(n, k=-2):
    result = np.zeros(n, dtype=int)
    result[k] = 1
    return result
Run Code Online (Sandbox Code Playgroud)

subseq具有固定大小(32)。该图遵循与之前相同的方案,为方便起见总结如下。

由于有如此多的替代方案,已进行了两个单独的分组,并省略了一些变化非常小且时序几乎相同的解决方案(即find_mix2()find_pivot2())。对于每组,测试两个输入。对于每个基准,都提供了完整图和最快方法的放大图。

天真随机

bm_full_naive_random bm_zoom_naive_random

天真无邪

bm_full_naive_worst bm_zoom_naive_worst

其他随机

bm_full_other_random bm_zoom_other_random

其他最差

bm_full_other_worst bm_zoom_other_worst

(完整代码可在此处获得。)


HYR*_*YRY 9

你可以调用tostring()方法将数组转换为字符串,然后你可以使用快速字符串搜索.当你有许多子阵列要检查时,这种方法可能会更快.

import numpy as np

a = np.array([1,2,3,4,5,6])
b = np.array([2,3,4])
print a.tostring().index(b.tostring())//a.itemsize
Run Code Online (Sandbox Code Playgroud)