ide*_*456 13 python algorithm python-3.x
Given k sorted arrays, what is the most efficient way to get the intersection of these lists?
Example
Input:
[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]
Output:
[1,7]
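Based on what I read in Elements of Programming Interviews, there is a way to get the union of k sorted arrays in n log k time. I would like to know whether something similar can be done for the intersection.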
# merge sorted arrays in n log k time [regular appending and merging is n log n time]
import heapq

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # a plain `if elem:` would also drop a legitimate 0
            heapq.heappush(heap, (elem, idx))
    res = []
    # collect results in n log k time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt is not None:
            heapq.heappush(heap, (nxt, ary))
    return res
Edit: Obviously this is an algorithm problem I am trying to solve, so I can't use any built-in functions such as set intersection.
Ray*_*ger 14
Here is an O(n) approach that needs no special data structures or auxiliary memory beyond the basic requirement of one iterator and one value per sublist:
from itertools import cycle

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
        pairs = cycle([(it := iter(sublist)), next(it)] for sublist in data)
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1  # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches != n:
                continue
            result.append(curr)
            while (value := next(iterator)) == curr:
                pass
            pair[VALUE] = value
            curr, matches = value, 1
    except StopIteration:
        return result
Here is a sample session:
>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> intersection(data)
[1, 7]
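The algorithm cycles around the iterator/value pairs. If a value matches across all pairs, it belongs in the intersection. If a value is lower than the current candidate, the current iterator is advanced. If a value is greater than any value seen so far, it becomes the new candidate and the match count is reset to 1. When any iterator is exhausted, the algorithm is done.
The use of `itertools.cycle()` is entirely optional. It is easily simulated by incrementing an index that wraps around at the end.
Instead of: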
iterator, value = pair = next(pairs)
You can write:
pairnum += 1
if pairnum == n:
    pairnum = 0
iterator, value = pair = pairs[pairnum]
Or more compactly:
pairnum = (pairnum + 1) % n
iterator, value = pair = pairs[pairnum]
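To show how the index-based replacement fits into the whole function, here is a minimal sketch. The name `intersection_indexed`, the choice to keep the pairs in a plain list, and the guard for fewer than two sublists are assumptions made for illustration; they are not part of the answer's code.

```python
def intersection_indexed(data):
    """Set-style intersection of sorted lists, using a wrap-around index."""
    n = len(data)
    if n < 2:
        # Assumption of this sketch: at least two sublists are supplied.
        raise ValueError("expected at least two sorted sublists")
    result = []
    try:
        # One [iterator, current value] pair per sublist, stored in a list
        # so that it can be indexed instead of cycled.
        pairs = [[(it := iter(sublist)), next(it)] for sublist in data]
        pairnum = 0
        curr = pairs[0][1]  # candidate: the largest value seen so far
        matches = 1         # number of sublists where the candidate occurs
        while True:
            pairnum = (pairnum + 1) % n  # wrap around at the end
            pair = pairs[pairnum]
            iterator, value = pair
            while value < curr:
                value = next(iterator)
            pair[1] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches != n:
                continue
            result.append(curr)
            # Skip further copies of the value that was just emitted.
            while (value := next(iterator)) == curr:
                pass
            pair[1] = value
            curr, matches = value, 1
    except StopIteration:
        # Any exhausted iterator ends the search.
        return result


print(intersection_indexed([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))  # [1, 7]
```

Because the pairs live in a list, `pairs[pairnum]` returns the same mutable pair on every visit, so updating `pair[1]` plays the role that `pair[VALUE] = value` plays in the `cycle()` version.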
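If duplicates are to be retained (as in a multiset), that is a simple modification. Just change the four lines after `result.append(curr)` so that the matching element is removed from each iterator: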
from itertools import cycle

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
        pairs = cycle([(it := iter(sublist)), next(it)] for sublist in data)
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1  # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches != n:
                continue
            result.append(curr)
            # Advance every iterator by one to consume the matched element
            for i in range(n):
                iterator, value = pair = next(pairs)
                pair[VALUE] = next(iterator)
            curr, matches = pair[VALUE], 1
    except StopIteration:
        return result
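One way to sanity-check a multiset-style result is to compare it against a reference built from `collections.Counter`, whose `&` operator keeps the minimum count of each key. The sketch below assumes that "retaining duplicates" means keeping each value as many times as it appears in every sublist. The helper name `multiset_intersection_reference` is hypothetical, and the helper uses built-ins the question rules out, so it is only a cross-check and not a substitute for the iterator-based code.

```python
from collections import Counter
from functools import reduce
from operator import and_


def multiset_intersection_reference(data):
    """Keep each value as often as its minimum count across all sublists."""
    if not data:
        return []
    # Counter & Counter keeps the smaller count for every shared key.
    common = reduce(and_, (Counter(sublist) for sublist in data))
    # elements() repeats each key according to its count.
    return sorted(common.elements())


print(multiset_intersection_reference([[1, 1], [1, 1, 2, 2, 3]]))  # [1, 1]
```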
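Yes, it is possible! I have modified your example code to do this.
My answer assumes that your question is about the algorithm. If you want the fastest-running code using sets, see the other answers.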
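This maintains the `O(n log(k))` time complexity: all the code between `if lowest != elem or ary != times_seen:` and `unbench_all = False` is `O(log(k))`. There is a nested loop (`for unbenched in range(times_seen):`) inside the main loop, but it only runs `times_seen` times. `times_seen` is initially 0, is reset to 0 after every run of that inner loop, and can only be incremented once per main-loop iteration, so the inner loop cannot perform more iterations in total than the main loop. Therefore, since the code inside the inner loop is `O(log(k))` and runs at most as many times as the outer loop, and the outer loop body is `O(log(k))` and runs `n` times, the algorithm is `O(n log(k))`.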
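This algorithm relies on how tuples are compared in Python. The first items of the tuples are compared, and if they are equal, the second items are compared (i.e. `(x, a) < (x, b)` is true if and only if `a < b`). In this algorithm, unlike the example code in the question, an item popped from the heap is not necessarily replaced by a new push in the same iteration. Since we need to check whether all sub-lists contain the same number, once a number has been popped from the heap its sub-list is what I call "benched", meaning that it is not added back to the heap. We first need to check whether the other sub-lists contain the same item, so the next item of that sub-list is not needed yet.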
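The tuple ordering described above can be checked directly. This is a small illustrative sketch; the sample entries are made up and only `heapq` from the standard library is used.

```python
import heapq

# Tuples compare item by item: first items first, second items only on a tie.
assert (2, 0) < (3, 0)      # first items differ, so they decide
assert (2, 0) < (2, 1)      # first items tie, so the sub-list index decides
assert not (2, 1) < (2, 0)

# Entries are (value, sub-list index), as in the answer's heap.
heap = []
for entry in [(3, 0), (2, 2), (2, 1)]:
    heapq.heappush(heap, entry)

smallest = heapq.heappop(heap)
print(smallest)  # (2, 1): lowest value, then lowest sub-list index
```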
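If a number really is in all sub-lists, the heap will look something like `[(2,0),(2,1),(2,2),(2,3)]`, with the first elements of all the tuples equal, so `heappop` will select the one with the lowest sub-list index. This means that index 0 is popped first and `times_seen` is incremented to 1, then index 1 is popped and `times_seen` is incremented to 2, and so on; if `ary` is not equal to `times_seen`, the number is not in the intersection of all the sub-lists. This leads to the condition `if lowest != elem or ary != times_seen:`, which decides when a number should not be in the result. The `else` branch of that `if` statement handles the case where the number might still be in the result.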
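The index check described above can be isolated in a few lines. The sketch below is a simplified illustration and not the answer's loop: it neither benches nor refills sub-lists, and the helper name `seen_in_all` is hypothetical. It only shows why comparing the popped sub-list index with a running counter detects a value shared by all `k` sub-lists.

```python
import heapq


def seen_in_all(entries, k):
    """True if the k smallest (value, index) entries share one value
    and carry the sub-list indices 0, 1, ..., k-1 in that order."""
    heap = list(entries)
    heapq.heapify(heap)
    if len(heap) < k or k == 0:
        return False
    lowest = heap[0][0]
    for times_seen in range(k):
        elem, ary = heapq.heappop(heap)
        # Mirrors the answer's condition: a different value, or a gap in
        # the sub-list indices, means some sub-list lacks the value.
        if elem != lowest or ary != times_seen:
            return False
    return True


print(seen_in_all([(2, 0), (2, 1), (2, 2), (2, 3)], 4))  # True
print(seen_in_all([(2, 0), (2, 2), (3, 1)], 3))          # False
```

In the second call, sub-list 1 currently holds 3, so the second pop yields `(2, 2)` while the counter expects index 1, and the check fails.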
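The `unbench_all` boolean is for when all sub-lists need to be taken off the bench. Judging from the code below, this happens in two cases: the current number is known not to be in the intersection of the sub-lists, or it has just been found in every sub-list.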
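When `unbench_all` is `True`, all the sub-lists that were removed from the heap are re-added. These are known to be the ones with indices in `range(times_seen)`: the algorithm only removes items from the heap while they carry the same number, so they must have been removed in index order, contiguously and starting from index 0, and there must be `times_seen` of them. This means that we do not need to store the indices of the benched sub-lists, only how many have been benched.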
import heapq

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # a plain `if elem:` would also drop a legitimate 0
            heapq.heappush(heap, (elem, idx))
    res = []
    # the number of times that the current number has been seen
    times_seen = 0
    # the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
    lowest = heap[0][0] if heap else None
    # collect results in n log k time
    while heap:
        elem, ary = heap[0]
        unbench_all = True
        if lowest != elem or ary != times_seen:
            if lowest == elem:
                heapq.heappop(heap)
                it = srtd_iters[ary]
                nxt = next(it, None)
                if nxt is not None:
                    heapq.heappush(heap, (nxt, ary))
        else:
            heapq.heappop(heap)
            times_seen += 1
            if times_seen == len(srtd_arys):
                res.append(elem)
            else:
                unbench_all = False
        if unbench_all:
            for unbenched in range(times_seen):
                unbenched_it = srtd_iters[unbenched]
                nxt = next(unbenched_it, None)
                if nxt is not None:
                    heapq.heappush(heap, (nxt, unbenched))
            times_seen = 0
            if heap:
                lowest = heap[0][0]
    return res

if __name__ == '__main__':
    a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
    a2 = [[1, 1], [1, 1, 2, 2, 3]]
    for arys in [a1, a2]:
        print(mergeArys(arys))
If you prefer, an equivalent algorithm can be written like this:
import heapq

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # a plain `if elem:` would also drop a legitimate 0
            heapq.heappush(heap, (elem, idx))
    res = []
    # collect results in n log k time
    while heap:
        elem, ary = heap[0]
        lowest = elem
        keep_elem = True
        for i in range(len(srtd_arys)):
            elem, ary = heap[0]
            if lowest != elem or ary != i:
                if ary != i:
                    heapq.heappop(heap)
                    it = srtd_iters[ary]
                    nxt = next(it, None)
                    if nxt is not None:
                        heapq.heappush(heap, (nxt, ary))
                keep_elem = False
                i -= 1
                break
            heapq.heappop(heap)
        if keep_elem:
            res.append(elem)
        # re-add the next item of every sub-list that was benched in this pass
        for unbenched in range(i+1):
            unbenched_it = srtd_iters[unbenched]
            nxt = next(unbenched_it, None)
            if nxt is not None:
                heapq.heappush(heap, (nxt, unbenched))
        # once any sub-list is exhausted, nothing more can be in the intersection
        if len(heap) < len(srtd_arys):
            heap = []
    return res