是否有更好的方法来删除长度等于或大于阈值的连续零部分?

lez*_*zaf 3 python performance signal-processing numpy

问题陈述:

\n

如标题所述,我想从一维数组中删除具有连续 长度等于或大于的部分。

\n
\n

我的解决方案:

\n

我生成了以下 MRE 中所示的解决方案:

\n
import numpy as np\n\nTHRESHOLD = 4\n\na = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))\n\nprint("Input: " + str(a))\n\n# Find the indices of the parts that meet threshold requirement\ngaps_above_threshold_inds = np.where(np.diff(np.nonzero(a)[0]) - 1 >= THRESHOLD)[0]\n\n# Delete these parts from array\nfor idx in gaps_above_threshold_inds:\n    a = np.delete(a, list(range(np.nonzero(a)[0][idx] + 1, np.nonzero(a)[0][idx + 1])))\n    \nprint("Output: " + str(a))\n
Run Code Online (Sandbox Code Playgroud)\n

输出:

\n
Input:  [1 1 0 1 0 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1]\nOutput: [1 1 0 1 1 1 0 0 0 1 1]\n
Run Code Online (Sandbox Code Playgroud)\n
\n

问题:

\n

有没有一种更简单更有效的方法来在 numpy 数组上执行此操作?

\n
\n

编辑:

\n

根据 @mozway 评论,我正在编辑我的问题以提供更多信息。

\n

基本上,问题域是:

\n
    \n
  • 我有长度约为 20.000 个样本的一维信号
  • \n
  • 由于噪声,部分信号已归零
  • \n
  • 信号的其余部分具有非零值,范围为 ~[50, 250]
  • \n
  • 前导零和尾随零已被删除
  • \n
\n

正如我已经说过的,我的目标是删除长度阈值以上的零部分。

\n

更详细的问题:

\n
    \n
  • numpy 高效处理而言,上述方案是否有更好的解决方案?
  • \n
  • 高效的 信号处理技术而言,是否有比使用 numpy 更合适的方法来实现预期目标?
  • \n
\n
\n
\n

对答案的评论:

\n

关于我对高效 numpy 处理的第一个担忧,@mathfux 的解决方案非常好,基本上就是我所寻找的。这就是我接受这个的原因。

\n

然而,@J\xc3\xa9r\xc3\xb4me Richard 的方法回答了我的第二个问题,它提供了一个真正高性能的解决方案;如果数据集非常大,则非常有用。

\n

感谢您的精彩回答!

\n

Jér*_*ard 5

np.delete每次调用时都会创建一个新数组,效率非常低。更快的解决方案是将所有值存储在掩码/布尔数组中,然后立即过滤输入数组。然而,如果仅使用 Numpy 完成,这仍然可能需要纯 Python 循环。一个更简单、更快的解决方案是使用Numba(或 Cython)来做到这一点。这是一个实现:

import numpy as np
import numba as nb

@nb.njit('int_[:](int_[:], int_)')
def filterZeros(arr, threshold):
    n = len(arr)
    res = np.empty(n, dtype=arr.dtype)
    count = 0
    j = 0
    for i in range(n):
        if arr[i] == 0:
            count += 1
        else:
            if count >= threshold:
                j -= count
            count = 0
        res[j] = arr[i]
        j += 1
    if n > 0 and arr[n-1] == 0 and count >= threshold:
        j -= count
    return res[0:j]

a = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))
a = filterZeros(a, 4)
print("Output: " + str(a))
Run Code Online (Sandbox Code Playgroud)

以下是我的机器上包含 100_000 个项目的随机二进制数组的结果:

Reference implementation: 5982    ms
Mozway's solution:          23.4  ms
This implementation:         0.11 ms
Run Code Online (Sandbox Code Playgroud)

因此,该解决方案比初始解决方案快约 54381 倍,比 Mozway 快212 倍。通过就地工作(销毁输入数组)并告诉 Numba 数组在内存中是连续的(使用::1而不是) ,代码甚至可以加快约 30% :

  • 使用 numba“就地”工作安全吗?RAM 中被忽略的部分会发生什么情况?会被垃圾收集器释放吗? (2认同)
  • @dankal444 在这种情况下,我认为它是安全的。事实上,写入的数组部分已经被读取(并且以后不会再次读取)。Numba 允许写入输入参数数组(就像 Numba 中实现的具有“out”参数的 Numpy 函数)。输出数组将是输入数组的 Numpy 子视图。AFAIK,视图由 Numpy 管理,尽管它们像任何 Python 对象一样进行引用计数。当没有视图或任何变量引用数组时,可以删除数组。AFAIK,允许返回子视图,并且 Numba *应该*正确更新对象。在实践中,它似乎有效。 (2认同)