已排序的 numpy 数组的交集

dsh*_*hin 6 numpy

我有一个排序的 numpy 数组列表。计算这些数组的排序交集的最有效方法是什么?

在我的应用程序中,我期望数组的数量小于 10^4,我期望单个数组的长度小于 10^7,我期望交集的长度接近 p*N,其中N 是最大数组的长度,其中 0.99 < p <= 1.0。这些数组是从磁盘加载的,如果它们一次不能全部装入内存,则可以分批加载。

一种快速而肮脏的方法是重复调用numpy.intersect1d(). 尽管intersect1d()没有利用数组已排序的事实,但这似乎效率低下。

B. *_* M. 1

由于intersect1d每次都对数组进行排序,因此效率很低。

在这里,您必须将交集和每个样本一起扫描以构建新的交集,这可以在线性时间内完成,并保持顺序。

此类任务通常必须通过低级例程手动调整。

这是一种方法numba

from numba import njit
import numpy as np

@njit
def drop_missing(intersect,sample):
    i=j=k=0
    new_intersect=np.empty_like(intersect)
    while i< intersect.size and j < sample.size:
            if intersect[i]==sample[j]: # the 99% case
                new_intersect[k]=intersect[i]
                k+=1
                i+=1
                j+=1
            elif intersect[i]<sample[j]:
                i+=1
            else : 
                j+=1
    return new_intersect[:k]  
Run Code Online (Sandbox Code Playgroud)

现在的样品:

n=10**7
ref=np.random.randint(0,n,n)  
ref.sort()

def perturbation(sample,k):
    rands=np.random.randint(0,n,k-1)
    rands.sort()
    l=np.split(sample,rands)
    return np.concatenate([a[:-1] for a in l])

samples=[perturbation(ref,100) for  _ in range(10)] #similar samples 
Run Code Online (Sandbox Code Playgroud)

并运行 10 个样品

def find_intersect(samples):
    intersect=samples[0]
    for sample in samples[1:]:
        intersect=drop_missing(intersect,sample)
    return intersect                

In [18]: %time u=find_intersect(samples)
Wall time: 307 ms

In [19]: len(u)
Out[19]: 9999009     
Run Code Online (Sandbox Code Playgroud)

这样看来,这项工作可以在大约 5 分钟内完成,超出了加载时间。