在二维数组上查找第 K 个最小元素(或中值)的最快算法?

5 python java arrays algorithm data-structures

我在相关主题上看到了很多 SO 主题,但没有一个提供有效的方法。

我想k-th在二维数组中找到最小的元素(或中位数),[1..M][1..N]其中每一行都按升序排序并且所有元素都是不同的。

我认为有O(M log MN)解决方案,但我不知道实施。(中位数的中位数或使用具有线性复杂性的分区是一些方法,但不再知道......)。

这是一个旧的谷歌面试问题,可以在这里搜索。

但现在我想提示或描述最有效的算法最快的算法)。

我也在这里读过一篇论文,但我不明白。

更新 1:此处找到一种解决方案,但当维度为奇数时。

bti*_*lly 5

所以要解决这个问题,它有助于解决一个稍微不同的问题。我们想知道每行中第 k 个截止点所在位置的上限/下限。那么我们就可以通过,验证下界及以下的事物数<k,上界及以下的事物数>k,并且它们之间只有一个值。

我已经提出了一种策略,可以在所有行中同时对这些边界进行二分搜索。作为二分搜索,它“应该”O(log(n))通过。每次通关O(m)总共涉及工作O(m log(n))次数。我把应该放在引号中,因为我没有证据证明它实际上需要O(log(n))通过。事实上,有可能在一行中过于激进,从其他行中发现所选的枢轴已关闭,然后不得不后退。但我相信它几乎没有退缩,实际上是O(m log(n)).

策略是跟踪下限、上限和中间的每一行。每次通过,我们都会对范围进行一系列加权,以降低、降低到中、从中到上、从上到结尾,权重是其中的事物数量,值是系列中的最后一个。然后我们在该数据结构中找到第 k 个值(按权重),并将其用作我们在每个维度中进行二分搜索的主元。

如果枢轴超出从下到上的范围,我们会通过在纠正错误的方向上加宽间隔来进行纠正。

当我们有正确的序列时,我们就有了答案。

有很多边缘情况,所以盯着完整的代码可能会有所帮助。

我还假设每一行的所有元素都是不同的。如果不是,您可能会陷入无限循环。(解决这意味着更多的边缘情况......)

import random

# This takes (k, [(value1, weight1), (value2, weight2), ...])
def weighted_kth (k, pairs):
    # This does quickselect for average O(len(pairs)).
    # Median of medians is deterministically the same, but a bit slower
    pivot = pairs[int(random.random() * len(pairs))][0]

    # Which side of our answer is the pivot on?
    weight_under_pivot = 0
    pivot_weight = 0
    for value, weight in pairs:
        if value < pivot:
            weight_under_pivot += weight
        elif value == pivot:
            pivot_weight += weight

    if weight_under_pivot + pivot_weight < k:
        filtered_pairs = []
        for pair in pairs:
            if pivot < pair[0]:
                filtered_pairs.append(pair)
        return weighted_kth (k - weight_under_pivot - pivot_weight, filtered_pairs)
    elif k <= weight_under_pivot:
        filtered_pairs = []
        for pair in pairs:
            if pair[0] < pivot:
                filtered_pairs.append(pair)
        return weighted_kth (k, filtered_pairs)
    else:
        return pivot

# This takes (k, [[...], [...], ...])
def kth_in_row_sorted_matrix (k, matrix):
    # The strategy is to discover the k'th value, and also discover where
    # that would be in each row.
    #
    # For each row we will track what we think the lower and upper bounds
    # are on where it is.  Those bounds start as the start and end and
    # will do a binary search.
    #
    # In each pass we will break each row into ranges from start to lower,
    # lower to mid, mid to upper, and upper to end.  Some ranges may be
    # empty.  We will then create a weighted list of ranges with the weight
    # being the length, and the value being the end of the list.  We find
    # where the k'th spot is in that list, and use that approximate value
    # to refine each range.  (There is a chance that a range is wrong, and
    # we will have to deal with that.)
    #
    # We finish when all of the uppers are above our k, all the lowers
    # one are below, and the upper/lower gap is more than 1 only when our
    # k'th element is in the middle.

    # Our data structure is simply [row, lower, upper, bound] for each row.
    data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
    is_search = True
    while is_search:
        pairs = []
        for row, lower, upper, bound in data:
            # Literal edge cases
            if 0 == upper:
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            elif lower == bound:
                pairs.append((row[lower], lower + 1))
            elif lower + 1 == upper: # No mid.
                pairs.append((row[lower], lower + 1))
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            else:
                mid = (upper + lower) // 2
                pairs.append((row[lower], lower + 1))
                pairs.append((row[mid], mid - lower))
                pairs.append((row[upper], upper - mid))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))

        pivot = weighted_kth(k, pairs)

        # Now that we have our pivot, we try to adjust our parameters.
        # If any adjusts we continue our search.
        is_search = False
        new_data = []
        for row, lower, upper, bound in data:
            # First cases where our bounds weren't bounds for our pivot.
            # We rebase the interval and either double the range.
            # - double the size of the range
            # - go halfway to the edge
            if 0 < lower and pivot <= row[lower]:
                is_search = True
                if pivot == row[lower]:
                    new_data.append((row, lower-1, min(lower+1, bound), bound))
                elif upper <= lower:
                    new_data.append((row, lower-1, lower, bound))
                else:
                    new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
            elif upper < bound and row[upper] <= pivot:
                is_search = True
                if pivot == row[upper]:
                    new_data.append((row, upper-1, upper+1, bound))
                elif lower < upper:
                    new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
                else:
                    new_data.append((row, upper, upper+1, bound))
            elif lower + 1 < upper:
                if upper == lower+2 and pivot == row[lower+1]:
                    new_data.append((row, lower, upper, bound)) # Looks like we found the pivot.
                else:
                    # We will split this interval.
                    is_search = True
                    mid = (upper + lower) // 2
                    if row[mid] < pivot:
                        new_data.append((row, mid, upper, bound))
                    elif pivot < row[mid] pivot:
                        new_data.append((row, lower, mid, bound))
                    else:
                        # We center our interval on the pivot
                        new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
            else:
                # We look like we found where the pivot would be in this row.
                new_data.append((row, lower, upper, bound))
        data = new_data # And set up the next search
    return pivot
Run Code Online (Sandbox Code Playgroud)


Nuc*_*man 5

已添加另一个答案以提供实际解决方案。由于评论中有相当多的兔子洞,因此保留了这个。


我相信最快的解决方案是 k-way 合并算法。它是一种O(N log K)将所有项目的K排序列表合并N为一个大小为单个排序列表的算法N

https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge

给出一个MxN列表。这最终是O(MNlog(M)). 但是,这是为了对整个列表进行排序。由于您只需要第一个K最小的项目而不是所有项目N*M,因此性能是O(Klog(M)). 这比您正在寻找的要好得多,假设O(K) <= O(M).

尽管这假设您已经N对 size 列表进行了排序M。如果您确实有Msize 的排序列表N,则可以通过更改循环数据的方式轻松处理(请参阅下面的伪代码),尽管这确实意味着性能O(K log(N))

k-way 合并只是将每个列表的第一项添加到堆或其他具有O(log N)插入和O(log N)查找思维的数据结构中。

k-way合并的伪代码看起来有点像这样:

  1. 对于每个排序列表,将第一个值插入到数据结构中,并通过某种方式确定该值来自哪个列表。IE:您可能会插入[value, row_index, col_index]到数据结构中,而不仅仅是value. 这也使您可以轻松处理列或行的循环。
  2. 从数据结构中删除最低值并附加到排序列表。
  3. 鉴于步骤#2 中的项目来自列表,I将列表中的下一个最低值添加I到数据结构中。IE:如果值是row 5 col 4 (data[5][4]). 然后,如果您将行用作列表,则下一个值将是row 5 col 5 (data[5][5]). 如果您使用的是列,则下一个值是row 6 col 4 (data[6][4])。插入此下一个数值为#1(即:数据结构,像你这样[value, row_index, col_index]
  4. 根据需要返回第 2 步。

根据您的需要,执行 2-4K次步骤。