通过所有相交条目的并集更新所有列表条目的最快方法

wil*_*il3 4 python arrays algorithm performance intersection

我正在寻找一种快速方法来遍历集合列表,并通过查找与列表中至少共享一个元素的任何其他元素的并集来扩展每个集合。

例如,假设我有四行数据,其中每一行对应于一组唯一元素

0, 5, 101
8, 9, 19, 21
78, 79
5, 7, 63, 64
Run Code Online (Sandbox Code Playgroud)

第一行和最后一行有相交元素 5,因此在执行操作后我想要并集

0, 5, 7, 63, 64, 101
8, 9, 19, 21
78, 79
0, 5, 7, 63, 64, 101
Run Code Online (Sandbox Code Playgroud)

现在,我几乎可以用两个循环来做到这一点:

def consolidate_list(arr):
    """
    arr (list) : A list of lists, where the inner lists correspond to sets of unique integers
    """
    arr_out = list()
    for item1 in arr:
        item_additional = list() # a list containing all overlapping elements
        for item2 in arr:
            if len(np.intersect1d(item1, item2)) > 0:
                item_additional.append(np.copy(item2))
        out_val = np.unique(np.hstack([np.copy(item1)] + item_additional)) # find union of all lists

        arr_out.append(out_val)
        
return arr_out
Run Code Online (Sandbox Code Playgroud)

这种方法的问题是它需要运行多次,直到输出停止变化。由于输入可能是锯齿状的(即每组元素的数量不同),因此我看不到矢量化此函数的方法。

tri*_*cot 5

这个问题是关于创建不相交的集合,所以我会使用并集查找方法。

现在Python并不是特别以速度快而闻名,但为了展示算法,这里是一个DisjointSet没有库的类的实现:

class DisjointSet:
    class Element:
        def __init__(self):
            self.parent = self
            self.rank = 0


    def __init__(self):
        self.elements = {}

    def find(self, key):
        el = self.elements.get(key, None)
        if not el:
            el = self.Element()
            self.elements[key] = el
        else: # Path splitting algorithm
            while el.parent != el:
                el, el.parent = el.parent, el.parent.parent
        return el

    def union(self, key=None, *otherkeys):
        if key is not None:
            root = self.find(key)
            for otherkey in otherkeys:
                el = self.find(otherkey)
                if el != root:
                    # Union by rank
                    if root.rank < el.rank:
                        root, el = el, root
                    el.parent = root
                    if root.rank == el.rank:
                        root.rank += 1

    def groups(self):
        result = { el: [] for el in self.elements.values() 
                          if el.parent == el }
        for key in self.elements:
            result[self.find(key)].append(key)
        return result
Run Code Online (Sandbox Code Playgroud)

以下是如何使用它来解决这个特定问题:

def solve(lists):
    disjoint = DisjointSet()
    for lst in lists:
        disjoint.union(*lst)
            
    groups = disjoint.groups()
    return [lst and groups[disjoint.find(lst[0])] for lst in lists]
Run Code Online (Sandbox Code Playgroud)

调用示例:

data = [
    [0, 5, 101],
    [8, 9, 19, 21],
    [],
    [78, 79],
    [5, 7, 63, 64]
]
result = solve(data)
Run Code Online (Sandbox Code Playgroud)

结果将是:

[[0, 5, 101, 7, 63, 64], [8, 9, 19, 21], [], [78, 79], [0, 5, 101, 7, 63, 64]]
Run Code Online (Sandbox Code Playgroud)

请注意,我在输入列表中添加了一个空列表,以说明此边界情况保持不变。

注意:有些库提供并集查找/不相交集功能,每个库的 API 略有不同,但我认为使用其中之一可以提供更好的性能。

  • 啊,谢谢你的解释。重新运行时输出变得相同。 (2认同)