重新启动cumsum,如果cumsum大于值,则获取索引

use*_*861 5 python numpy pandas

说我有一段距离x=[1,2,1,3,3,2,1,5,1,1]

我想从x到达总和达到10的索引,在这种情况下,idx = [4,9]。

因此,满足条件后,cumsum重新启动。

我可以使用循环来完成此操作,但是对于大型阵列而言,循环速度很慢,我想知道是否可以以某种vectorized方式进行。

WeN*_*Ben 8

一个有趣的方法

sumlm = np.frompyfunc(lambda a,b:a+b if a < 10 else b,2,1)
newx=sumlm.accumulate(x, dtype=np.object)
newx
array([1, 3, 4, 7, 10, 2, 3, 8, 9, 10], dtype=object)
np.nonzero(newx==10)

(array([4, 9]),)
Run Code Online (Sandbox Code Playgroud)


piR*_*red 6

循环并不总是不好的(尤其是在需要循环时)。另外,没有工具或算法可以使此过程快于O(n)。因此,让我们进行一个良好的循环。

发电机功能

def cumsum_breach(x, target):
    total = 0
    for i, y in enumerate(x):
        total += y
        if total >= target:
            yield i
            total = 0

list(cumsum_breach(x, 10))

[4, 9]
Run Code Online (Sandbox Code Playgroud)

用Numba及时编译

Numba是需要安装的第三方库。
Numba可能会对支持哪些功能感到困惑。但这有效。
而且,正如Divakar指出的那样,Numba在数组上的表现更好

from numba import njit

@njit
def cumsum_breach_numba(x, target):
    total = 0
    result = []
    for i, y in enumerate(x):
        total += y
        if total >= target:
            result.append(i)
            total = 0

    return result

cumsum_breach_numba(x, 10)
Run Code Online (Sandbox Code Playgroud)

测试两个

因为我喜欢 ¯\_(?)_/¯

设定

np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()
Run Code Online (Sandbox Code Playgroud)

准确性

i0 = cumsum_breach_numba(x0, 200_000)
i1 = list(cumsum_breach(x1, 200_000))

assert i0 == i1
Run Code Online (Sandbox Code Playgroud)

时间

%timeit cumsum_breach_numba(x0, 200_000)
%timeit list(cumsum_breach(x1, 200_000))

582 µs ± 40.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
64.3 ms ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Run Code Online (Sandbox Code Playgroud)

Numba快100倍左右。

为了进行更真实的测试,我将列表转换为Numpy数组

%timeit cumsum_breach_numba(np.array(x1), 200_000)
%timeit list(cumsum_breach(x1, 200_000))

43.1 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.8 ms ± 327 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Run Code Online (Sandbox Code Playgroud)

这使它们达到大约。


Div*_*kar 6

这是一个带有numba和数组初始化的代码-

from numba import njit

@njit
def cumsum_breach_numba2(x, target, result):
    total = 0
    iterID = 0
    for i,x_i in enumerate(x):
        total += x_i
        if total >= target:
            result[iterID] = i
            iterID += 1
            total = 0
    return iterID

def cumsum_breach_array_init(x, target):
    x = np.asarray(x)
    result = np.empty(len(x),dtype=np.uint64)
    idx = cumsum_breach_numba2(x, target, result)
    return result[:idx]
Run Code Online (Sandbox Code Playgroud)

时机

包括@piRSquared's solutions并使用同一篇文章中的基准测试设置-

In [58]: np.random.seed([3, 1415])
    ...: x = np.random.randint(100, size=1000000).tolist()

# @piRSquared soln1
In [59]: %timeit list(cumsum_breach(x, 10))
10 loops, best of 3: 73.2 ms per loop

# @piRSquared soln2
In [60]: %timeit cumsum_breach_numba(np.asarray(x), 10)
10 loops, best of 3: 69.2 ms per loop

# From this post
In [61]: %timeit cumsum_breach_array_init(x, 10)
10 loops, best of 3: 39.1 ms per loop
Run Code Online (Sandbox Code Playgroud)

Numba:追加与数组初始化

为了更仔细地了解数组初始化的作用,这似乎是两个numba实现之间的最大区别,让我们将它们放在数组数据上,因为数组数据的创建本身就很耗时,而且它们都依赖于它-

In [62]: x = np.array(x)

In [63]: %timeit cumsum_breach_numba(x, 10)# with appending
10 loops, best of 3: 31.5 ms per loop

In [64]: %timeit cumsum_breach_array_init(x, 10)
1000 loops, best of 3: 1.8 ms per loop
Run Code Online (Sandbox Code Playgroud)

为了强制输出拥有自己的存储空间,我们可以制作一个副本。虽然不会大幅度改变事情-

In [65]: %timeit cumsum_breach_array_init(x, 10).copy()
100 loops, best of 3: 2.67 ms per loop
Run Code Online (Sandbox Code Playgroud)