lez*_*zaf 3 python performance signal-processing numpy
如标题所述,我想从一维数组中删除具有连续 零且长度等于或大于阈值的部分。
\n我生成了以下 MRE 中所示的解决方案:
\nimport numpy as np\n\nTHRESHOLD = 4\n\na = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))\n\nprint("Input: " + str(a))\n\n# Find the indices of the parts that meet threshold requirement\ngaps_above_threshold_inds = np.where(np.diff(np.nonzero(a)[0]) - 1 >= THRESHOLD)[0]\n\n# Delete these parts from array\nfor idx in gaps_above_threshold_inds:\n a = np.delete(a, list(range(np.nonzero(a)[0][idx] + 1, np.nonzero(a)[0][idx + 1])))\n \nprint("Output: " + str(a))\nRun Code Online (Sandbox Code Playgroud)\n输出:
\nInput: [1 1 0 1 0 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1]\nOutput: [1 1 0 1 1 1 0 0 0 1 1]\nRun Code Online (Sandbox Code Playgroud)\n有没有一种更简单、更有效的方法来在 numpy 数组上执行此操作?
\n根据 @mozway 评论,我正在编辑我的问题以提供更多信息。
\n基本上,问题域是:
\n正如我已经说过的,我的目标是删除长度阈值以上的零部分。
\n关于我对高效 numpy 处理的第一个担忧,@mathfux 的解决方案非常好,基本上就是我所寻找的。这就是我接受这个的原因。
\n然而,@J\xc3\xa9r\xc3\xb4me Richard 的方法回答了我的第二个问题,它提供了一个真正高性能的解决方案;如果数据集非常大,则非常有用。
\n感谢您的精彩回答!
\nnp.delete每次调用时都会创建一个新数组,效率非常低。更快的解决方案是将所有值存储在掩码/布尔数组中,然后立即过滤输入数组。然而,如果仅使用 Numpy 完成,这仍然可能需要纯 Python 循环。一个更简单、更快的解决方案是使用Numba(或 Cython)来做到这一点。这是一个实现:
import numpy as np
import numba as nb
@nb.njit('int_[:](int_[:], int_)')
def filterZeros(arr, threshold):
n = len(arr)
res = np.empty(n, dtype=arr.dtype)
count = 0
j = 0
for i in range(n):
if arr[i] == 0:
count += 1
else:
if count >= threshold:
j -= count
count = 0
res[j] = arr[i]
j += 1
if n > 0 and arr[n-1] == 0 and count >= threshold:
j -= count
return res[0:j]
a = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))
a = filterZeros(a, 4)
print("Output: " + str(a))
Run Code Online (Sandbox Code Playgroud)
以下是我的机器上包含 100_000 个项目的随机二进制数组的结果:
Reference implementation: 5982 ms
Mozway's solution: 23.4 ms
This implementation: 0.11 ms
Run Code Online (Sandbox Code Playgroud)
因此,该解决方案比初始解决方案快约 54381 倍,比 Mozway 快212 倍。通过就地工作(销毁输入数组)并告诉 Numba 数组在内存中是连续的(使用::1而不是) ,代码甚至可以加快约 30% :。