What is the most efficient way to get the intersection of k sorted arrays?

ide*_*456 13 python algorithm python-3.x

Given k sorted arrays, what is the most efficient way of getting the intersection of these lists?

Example

Input:

[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]] 

Output:

[1,7]

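I read in Elements of Programming Interviews about a way to get the union of k sorted arrays in n log k time. I'd like to know whether there is a way to do something similar for the intersection: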

## merge sorted arrays in nlogk time [ regular appending and merging is nlogn time ]
import heapq
def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # "if elem:" would wrongly skip a 0
            heapq.heappush(heap, (elem, idx))
    
    res = []
 
    # collect results in nlogK time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt is not None:
            heapq.heappush(heap, (nxt, ary))

    return res

Edit: Obviously this is an algorithm problem I'm trying to solve, so I can't use any built-in functions such as set intersection.
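For comparison, a simple baseline that also avoids built-in set operations: intersect the arrays pairwise with the classic two-pointer walk and fold left to right. This is a different technique from the heap merge above (my own sketch, not from the question or the answers), and the names intersect_two, intersect_k and the keep_duplicates flag are hypothetical. Because the running result never has more elements than the array it was last intersected with, the total work stays linear in n, the total number of elements.

```python
def intersect_two(a, b, keep_duplicates=False):
    """Two-pointer intersection of two sorted lists."""
    i = j = 0
    out = []
    while i < len(a) and j < len(b):
        if a[i] < b[j]:
            i += 1
        elif a[i] > b[j]:
            j += 1
        else:
            # Common value: each match consumes one copy from both lists.
            # In set mode, only record it if it differs from the last output.
            if keep_duplicates or not out or out[-1] != a[i]:
                out.append(a[i])
            i += 1
            j += 1
    return out


def intersect_k(arrays, keep_duplicates=False):
    """Fold intersect_two over all arrays (O(n) for n total elements)."""
    if not arrays:
        return []
    # Intersecting the first array with itself removes duplicates (set mode)
    # or copies it unchanged (multiset mode)
    result = intersect_two(arrays[0], arrays[0], keep_duplicates)
    for arr in arrays[1:]:
        if not result:
            break  # nothing left that could be common to all arrays
        result = intersect_two(result, arr, keep_duplicates)
    return result


print(intersect_k([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))          # [1, 7]
print(intersect_k([[1, 1], [1, 1, 2, 2, 3]], keep_duplicates=True))        # [1, 1]
```

The keep_duplicates=True mode gives multiset semantics (each common value appears as many times as its smallest count), which can be handy for cross-checking the duplicate-preserving answers below.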

Ray*_*ger 14

Exploit the sort order

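Here is an O(n) approach that doesn't require any special data structures or auxiliary memory beyond the fundamental requirement of one iterator and one value per sublist: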

from itertools import cycle

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
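        # Build the pairs in a list comprehension, not a generator expression:
        # inside a generator, StopIteration from an empty sublist becomes a RuntimeError
        pairs = cycle([[(it := iter(sublist)), next(it)] for sublist in data])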
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1         # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches < n:  # '<' rather than '!=' so a single sublist (n == 1) can't loop forever
                continue
            result.append(curr)
            while (value := next(iterator)) == curr:
                pass
            pair[VALUE] = value
            curr, matches = value, 1
    except StopIteration:
        return result

Here is a sample session:

>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> intersection(data)
[1, 7]

Algorithm in words

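The algorithm cycles around the iterator-value pairs. If a value matches across all pairs, it belongs in the intersection. If a value is lower than the current candidate (the largest value seen so far), the current iterator is advanced. If a value is greater than any value seen so far, it becomes the new target and the match count is reset to 1. When any iterator is exhausted, the algorithm is done.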

Doesn't depend on built-in functions

The use of itertools.cycle() is entirely optional. It can easily be simulated by incrementing an index that wraps around at the end.

Instead of:

iterator, value = pair = next(pairs)

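You could write (with pairs built as a plain list of [iterator, value] pairs instead of a cycle, and pairnum starting at 0):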

pairnum += 1
if pairnum == n:
    pairnum = 0
iterator, value = pair = pairs[pairnum]    

Or, more compactly:

pairnum = (pairnum + 1) % n
iterator, value = pair = pairs[pairnum] 
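Putting those pieces together, here is a sketch of the whole function with the wrapping index instead of itertools.cycle(). Assumptions that are mine, not the answer's: the name intersection_no_cycle is hypothetical, the pairs are built with a plain loop, and the test is matches < n instead of matches != n so that a single input list (n == 1) terminates. It needs Python 3.8+ for :=.

```python
def intersection_no_cycle(data):
    """Same loop as intersection(), but with a wrapping index instead of cycle()."""
    VALUE = 1                      # position of the current value in each pair
    n = len(data)
    result = []
    if n == 0:
        return result
    try:
        # One mutable [iterator, current value] pair per sublist.
        # next() raises StopIteration here if a sublist is empty.
        pairs = []
        for sublist in data:
            it = iter(sublist)
            pairs.append([it, next(it)])
        pairnum = 0
        curr = pairs[0][VALUE]     # candidate: the largest value seen so far
        matches = 1                # pairs seen in this round that hold curr
        while True:
            pairnum = (pairnum + 1) % n        # wrap around instead of cycle()
            iterator, value = pair = pairs[pairnum]
            while value < curr:    # too small to be common to all sublists
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:       # larger value: it becomes the new candidate
                curr, matches = value, 1
                continue
            matches += 1
            if matches < n:
                continue
            result.append(curr)
            # Skip the remaining copies of curr in this sublist
            while (value := next(iterator)) == curr:
                pass
            pair[VALUE] = value
            curr, matches = value, 1
    except StopIteration:          # some sublist is exhausted: done
        return result


print(intersection_no_cycle([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))  # [1, 7]
```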

Duplicate values

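If you want to keep duplicates (like a multiset), it's a simple modification: just change the four lines after result.append(curr) so that they remove the matching element from each iterator: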

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
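        # Build the pairs in a list comprehension, not a generator expression:
        # inside a generator, StopIteration from an empty sublist becomes a RuntimeError
        pairs = cycle([[(it := iter(sublist)), next(it)] for sublist in data])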
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1         # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches < n:  # '<' rather than '!=' so a single sublist (n == 1) can't loop forever
                continue
            result.append(curr)
            for i in range(n):
                iterator, value = pair = next(pairs)
                pair[VALUE] = next(iterator)
            curr, matches = pair[VALUE], 1
    except StopIteration:
        return result


Oli*_*Oli 5

Yes, it's possible! I've modified your example code to do this.

My answer assumes your question is about the algorithm. If you want the fastest-running code using sets, see the other answers.

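This keeps the O(n log(k)) time complexity: all the code between if lowest != elem or ary != times_seen: and unbench_all = False is O(log(k)). There is a nested loop inside the main loop (for unbenched in range(times_seen):), but it only runs times_seen times, and times_seen starts at 0, is reset to 0 every time that inner loop runs, and can be incremented at most once per main-loop iteration, so the inner loop cannot perform more iterations in total than the main loop. Since the code inside the inner loop is O(log(k)) and runs at most as many times as the outer loop, and the outer loop body is O(log(k)) and runs O(n) times, the algorithm is O(n log(k)).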

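This algorithm relies on the way tuples are compared in Python: the first items of the tuples are compared and, if they are equal, the second items are compared (i.e. (x, a) < (x, b) is true if and only if a < b). In this algorithm, unlike the example code in the question, an item popped from the heap is not necessarily pushed back in the same iteration. Since we need to check whether all sub-lists contain the same number, once a number has been popped from the heap its sub-list is what I call "benched", meaning it is not added back to the heap. We still need to check whether the other sub-lists contain the same item, so there is no need to add the benched sub-list's next item yet.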
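A quick illustration of the tuple ordering this relies on, using made-up (value, sub-list index) entries: equal values come off the heap in ascending sub-list index order.

```python
import heapq

# (value, sublist_index) tuples compare by value first, then by index
assert (2, 0) < (2, 1) < (3, 0)

heap = []
for item in [(2, 3), (3, 0), (2, 0), (2, 1)]:
    heapq.heappush(heap, item)

# Pop everything: ties on the value are broken by the smaller index
popped = [heapq.heappop(heap) for _ in range(len(heap))]
print(popped)  # [(2, 0), (2, 1), (2, 3), (3, 0)]
```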

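If a number really is in all the sub-lists, the heap will look like [(2,0),(2,1),(2,2),(2,3)]. The first elements of the tuples are all the same, so heappop will select the one with the lowest sub-list index. This means index 0 is popped first and times_seen increases to 1, then index 1 is popped and times_seen increases to 2, and so on. If ary is not equal to times_seen, the number is not in the intersection of all the sub-lists. This leads to the condition if lowest != elem or ary != times_seen:, which decides when a number should not be in the result. The else branch of that if statement handles the case where the number may still be in the result.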
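A stripped-down sketch of just that check, assuming the heap holds one (head value, sub-list index) entry per sub-list. The helper name leading_matches is hypothetical, and unlike the real code it does not advance any iterators; it only counts how many sub-lists would be benched.

```python
import heapq

def leading_matches(heads):
    """Count sub-lists 0, 1, 2, ... whose head equals the lowest value,
    stopping at the first entry that fails lowest == elem and ary == times_seen."""
    heap = list(heads)          # heads[i] == (head value of sub-list i, i)
    heapq.heapify(heap)
    lowest = heap[0][0]
    times_seen = 0
    while heap:
        elem, ary = heap[0]
        if lowest != elem or ary != times_seen:
            break               # the number is not in every sub-list
        heapq.heappop(heap)     # bench sub-list `ary`
        times_seen += 1
    return times_seen


# Every sub-list starts with 2: all four are benched, so 2 is in the intersection
print(leading_matches([(2, 0), (2, 1), (2, 2), (2, 3)]))  # 4
# Sub-list 2 starts with 3: index 3 surfaces where index 2 was expected
print(leading_matches([(2, 0), (2, 1), (3, 2), (2, 3)]))  # 2
```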

The unbench_all boolean is for when all the sub-lists need to be taken off the bench. This can happen because:

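  1. The current number is known not to be in the intersection of the sub-lists
  2. The current number is known to be in the intersection of the sub-lists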

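When unbench_all is True, every sub-list that was removed from the heap is added back. These are known to be exactly the sub-lists with indices in range(times_seen): the algorithm only removes items from the heap while they share the same number, so they must come off in index order, contiguously, starting from index 0, and there must be times_seen of them. This means we don't need to store the indices of the benched sub-lists, only how many of them have been benched.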
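A small sketch of one bench/unbench round, to show that the benched sub-lists can be recovered as range(times_seen) without recording their indices. The function name bench_then_unbench is hypothetical, and it assumes none of the input arrays is empty.

```python
import heapq

def bench_then_unbench(arrays):
    """Bench the sub-lists holding the lowest head (in index order), then
    re-push the next item of each benched sub-list, as unbench_all does."""
    iters = [iter(a) for a in arrays]
    heap = [(next(it), idx) for idx, it in enumerate(iters)]  # assumes no empty arrays
    heapq.heapify(heap)
    lowest = heap[0][0]
    times_seen = 0
    while heap and heap[0] == (lowest, times_seen):
        heapq.heappop(heap)           # bench sub-list number times_seen
        times_seen += 1
    benched = list(range(times_seen))  # recovered from the count alone
    for idx in benched:
        nxt = next(iters[idx], None)
        if nxt is not None:
            heapq.heappush(heap, (nxt, idx))
    return benched, sorted(heap)


print(bench_then_unbench([[1, 3], [1, 4], [2, 5]]))  # ([0, 1], [(2, 2), (3, 0), (4, 1)])
```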

import heapq


def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # "if elem:" would wrongly skip a 0
            heapq.heappush(heap, (elem, idx))

    res = []

    # the number of times that the current number has been seen
    times_seen = 0

    # the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
    lowest = heap[0][0] if heap else None

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        unbench_all = True

        if lowest != elem or ary != times_seen:
            if lowest == elem:
                heapq.heappop(heap)
                it = srtd_iters[ary]
                nxt = next(it, None)
                if nxt is not None:
                    heapq.heappush(heap, (nxt, ary))
        else:
            heapq.heappop(heap)
            times_seen += 1

            if times_seen == len(srtd_arys):
                res.append(elem)
            else:
                unbench_all = False

        if unbench_all:
            for unbenched in range(times_seen):
                unbenched_it = srtd_iters[unbenched]
                nxt = next(unbenched_it, None)
                if nxt is not None:
                    heapq.heappush(heap, (nxt, unbenched))
            times_seen = 0
            if heap:
                lowest = heap[0][0]

    return res


if __name__ == '__main__':
    a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
    a2 = [[1, 1], [1, 1, 2, 2, 3]]
    for arys in [a1, a2]:
        print(mergeArys(arys))

If you prefer, the equivalent algorithm can be written like this:

import heapq


def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # "if elem:" would wrongly skip a 0
            heapq.heappush(heap, (elem, idx))

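    # If any array is empty, the intersection is empty; returning early also
    # stops heap[0] in the loop below from raising IndexError
    if len(heap) < len(srtd_arys):
        return []

    res = []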

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        lowest = elem
        keep_elem = True
        for i in range(len(srtd_arys)):
            elem, ary = heap[0]
            if lowest != elem or ary != i:
                if ary != i:
                    heapq.heappop(heap)
                    it = srtd_iters[ary]
                    nxt = next(it, None)
                    if nxt is not None:
                        heapq.heappush(heap, (nxt, ary))

                keep_elem = False
                i -= 1
                break
            heapq.heappop(heap)

        if keep_elem:
            res.append(elem)

        for unbenched in range(i+1):
            unbenched_it = srtd_iters[unbenched]
            nxt = next(unbenched_it, None)
            if nxt is not None:
                heapq.heappush(heap, (nxt, unbenched))

        if len(heap) < len(srtd_arys):
            heap = []

    return res
