如何检查一个序列是严格单调的还是有一个双方都严格单调的转折点?

Jac*_*ack 2 python numpy pandas

输入

l1=[1,3,5,6,7]
l2=[1,2,2,3,4]
l3=[5,4,3,2,1]
l4=[5,5,3,2,1]
l5=[1,2,3,4.1,3,2]
l6=[3,2,1,0.4,1,2,3]
l7=[1,2,10,4,8,9,2]
l8=[1,2,3,4,4,3,2,1]
l9=[-0.05701686,  0.57707936, -0.34602634, -0.02599778]
l10=[ 0.13556905,  0.45859   , -0.34602634, -0.09178798,  0.03044908]
l11=[-0.38643975, -0.09178798,  0.57707936, -0.05701686,  0.00649252]
Run Code Online (Sandbox Code Playgroud)

注意:序列中的值是浮点数。

预期的

  • 编写一个函数find_targeted_seq,该函数返回一个序列,无论是严格单调还是有一个双方都严格单调的转折点。例如,l1, l3, l5,l6是预期的。

尝试

nor*_*ok2 12

Python 的标准库和 NumPy 都没有特定的原语来解决这个任务。但是,NumPy 中的传统方法是使用np.diff().

要调查转折点,您可以分别使用np.argmin()np.argmax()

严格的单调性条件对应于:(np.all(np.diff(arr) > 0)增加)或np.all(np.diff(arr) < 0)(减少)。一个转折点(pivot)的要求,相当于找到那个转折点并分别检查序列是否单调。对于多个连续的最小值或最大值,如果使用遇到的第一个最小值或最大值并检查排除该最小值或最大值的左分支的单调性,这足以保证将正确识别两个连续的最小值或最大值。

因此,一个简单的实现如下:

import numpy as np


def find_targeted_seq_np(seq):
    diffs = np.diff(seq)
    incr = diffs > 0
    decr = diffs < 0
    if np.all(incr) or np.all(decr):
        return True
    maximum = np.argmax(seq)
    if np.all(incr[:maximum]) and np.all(decr[maximum:]):
        return True
    minimum = np.argmin(seq)
    if np.all(decr[:minimum]) and np.all(incr[minimum:]):
        return True
    return False
Run Code Online (Sandbox Code Playgroud)

(这与@DaniMesejo 的回答中的想法基本相同)。


另一种选择是使用 的组合np.diff()np.sign()np.count_nonzero()计算单调性发生变化的次数。如果这是 0 或 1,则序列有效。避免重复元素在符号变化的计数中是内置的,除非重复元素位于序列的开头或结尾,并且必须明确检查这种情况。这导致了一个非常简洁的解决方案:

import numpy as np


def find_targeted_seq_np2(seq):
    diffs = np.diff(seq)
    return \
        diffs[0] != 0 and diffs[-1] != 0  \
        and np.count_nonzero(np.diff(np.sign(diffs))) < 2
Run Code Online (Sandbox Code Playgroud)

(这与@yonatansc97's answer 中的想法基本相同,但没有np.isin()按照@DaniMesejo 的评论中的建议使用)。


或者,可以考虑使用显式循环。这具有显着更高的内存效率和更好的短路特性的优点:

def find_targeted_seq(seq):
    n = len(seq)
    changes = 0
    x = seq[1]
    last_x = seq[0]
    diff = x - last_x
    if diff > 0:
        monotonic = 1
    elif diff < 0:
        monotonic = -1
    else:  # diff == 0
        return False
    for i in range(1, n):
        x = seq[i]
        diff = x - last_x
        if diff == 0:
            return False
        elif (diff > 0 and monotonic == -1) or (diff < 0 and monotonic == 1):
            changes += 1
            monotonic = -monotonic
        if changes > 1:
            return False
        last_x = x
    return True
Run Code Online (Sandbox Code Playgroud)

此外,如果可以保证序列元素的类型稳定性,则可以通过 Numba 轻松加速:

import numba as nb


_find_targeted_seq_nb = nb.njit(find_targeted_seq)


def find_targeted_seq_nb(seq):
    return _find_targeted_seq_nb(np.array(seq))
Run Code Online (Sandbox Code Playgroud)

为了进行比较,这里报告了一个使用pandas(它提供了一些用于单调性检查的原语)和scipy.signal.argrelmin()/scipy.signal.argrelmax()用于寻找转折点(此代码与@DaniMesejo的答案中的代码基本相同)的实现,例如:

from scipy.signal import argrelmin, argrelmax
import pandas as pd


def is_strictly_monotonic_increasing(s):
    return s.is_unique and s.is_monotonic_increasing


def is_strictly_monotonic_decreasing(s):
    return s.is_unique and s.is_monotonic_decreasing


def find_targeted_seq_pd(lst):
    ser = pd.Series(lst)
    if is_strictly_monotonic_increasing(ser) or is_strictly_monotonic_decreasing(ser):
        return True
    minima, *_ = argrelmin(ser.values)
    if len(minima) == 1:  # only on minimum turning point
        idx = minima[0]
        return is_strictly_monotonic_decreasing(ser[:idx]) and is_strictly_monotonic_increasing(ser[idx:])
    maxima, *_ = argrelmax(ser.values)
    if len(maxima) == 1:  # only on maximum turning point
        idx = maxima[0]
        return is_strictly_monotonic_increasing(ser[:idx]) and is_strictly_monotonic_decreasing(ser[idx:])
    return False
Run Code Online (Sandbox Code Playgroud)

应用于给定输入的这些解决方案都确实给出了正确的结果:

data = (
    ((1, 3, 5, 6, 7), True),  # l1
    ((1, 2, 2, 3, 4), False),  # l2
    ((5, 4, 3, 2, 1), True),  # l3
    ((5, 5, 3, 2, 1), False),  # l4
    ((1, 2, 3, 4.1, 3, 2), True),  # l5
    ((3, 2, 1, 0.5, 1, 2), True),  # this value was added in addition to the existing ones
    ((3, 2, 1, 0.4, 1, 2, 3), True),  # l6
    ((1, 2, 10, 4, 8, 9, 2), False),  # l7
    ((1, 2, 3, 4, 4, 3, 2, 1), False),  # l8
    ((-0.05701686, 0.57707936, -0.34602634, -0.02599778), False),  # l9
    ((0.13556905, 0.45859, -0.34602634, -0.09178798, 0.03044908), False),  # l10
    ((-0.38643975, -0.09178798, 0.57707936, -0.05701686, 0.00649252), False),  # l11
)


funcs = find_targeted_seq_np, find_targeted_seq_np2, find_targeted_seq_pd, find_targeted_seq, find_targeted_seq_nb

for func in funcs:
    print(func.__name__, all(func(seq) == result for seq, result in data))
# find_targeted_seq_np True
# find_targeted_seq_np2 True
# find_targeted_seq_pd True
# find_targeted_seq True
# find_targeted_seq_nb True
Run Code Online (Sandbox Code Playgroud)

在时间上,建议数据的一些简单基准清楚地表明直接循环(有或没有 Numba 加速)是最快的。第二种 Numpy 方法比第一种 NumPy appraoch 快得多,而pandas基于-based的方法最慢:

for func in funcs:
    print(func.__name__)
    %timeit [func(seq) == result for seq, result in data]
    print()

# find_targeted_seq_np
# 1000 loops, best of 3: 530 µs per loop

# find_targeted_seq_np2
# 10000 loops, best of 3: 187 µs per loop

# find_targeted_seq_pd
# 100 loops, best of 3: 4.68 ms per loop

# find_targeted_seq
# 100000 loops, best of 3: 14.6 µs per loop

# find_targeted_seq_nb
# 10000 loops, best of 3: 19.9 µs per loop
Run Code Online (Sandbox Code Playgroud)

虽然在给定输入上,此测试数据的直接循环比基于 NumPy 的方法更快,但后者应该可以更好地扩展输入大小。在所有尺度上,Numba 方法都可能比 NumPy 方法更快。