我有一个排序的 numpy 数组列表。计算这些数组的排序交集的最有效方法是什么?
在我的应用程序中,我期望数组的数量小于 10^4,我期望单个数组的长度小于 10^7,我期望交集的长度接近 p*N,其中N 是最大数组的长度,其中 0.99 < p <= 1.0。这些数组是从磁盘加载的,如果它们一次不能全部装入内存,则可以分批加载。
一种快速而肮脏的方法是重复调用numpy.intersect1d(). 尽管intersect1d()没有利用数组已排序的事实,但这似乎效率低下。
由于intersect1d每次都对数组进行排序,因此效率很低。
在这里,您必须将交集和每个样本一起扫描以构建新的交集,这可以在线性时间内完成,并保持顺序。
此类任务通常必须通过低级例程手动调整。
这是一种方法numba:
from numba import njit
import numpy as np
@njit
def drop_missing(intersect,sample):
i=j=k=0
new_intersect=np.empty_like(intersect)
while i< intersect.size and j < sample.size:
if intersect[i]==sample[j]: # the 99% case
new_intersect[k]=intersect[i]
k+=1
i+=1
j+=1
elif intersect[i]<sample[j]:
i+=1
else :
j+=1
return new_intersect[:k]
Run Code Online (Sandbox Code Playgroud)
现在的样品:
n=10**7
ref=np.random.randint(0,n,n)
ref.sort()
def perturbation(sample,k):
rands=np.random.randint(0,n,k-1)
rands.sort()
l=np.split(sample,rands)
return np.concatenate([a[:-1] for a in l])
samples=[perturbation(ref,100) for _ in range(10)] #similar samples
Run Code Online (Sandbox Code Playgroud)
并运行 10 个样品
def find_intersect(samples):
intersect=samples[0]
for sample in samples[1:]:
intersect=drop_missing(intersect,sample)
return intersect
In [18]: %time u=find_intersect(samples)
Wall time: 307 ms
In [19]: len(u)
Out[19]: 9999009
Run Code Online (Sandbox Code Playgroud)
这样看来,这项工作可以在大约 5 分钟内完成,超出了加载时间。
| 归档时间: |
|
| 查看次数: |
1018 次 |
| 最近记录: |