需要帮助理解"平衡0-1矩阵"的动态规划方法?

Rut*_*mar 6 algorithm dynamic-programming

问题:我正在努力理解/可视化"动态规划 - 维基百科文章中的一种平衡0-1矩阵的动态规划方法".

维基百科链接:https://en.wikipedia.org/wiki/Dynamic_programming#A_type_of_balanced_0.E2.80.931_matrix

在处理多维数组时,我无法理解memoization的工作原理.例如,当尝试使用DP解决Fibonacci系列时,使用数组来存储先前的状态结果很容易,因为数组的索引值存储该状态的解决方案.

有人可以用更简单的方式解释DP方法的"0-1平衡矩阵"吗?

bti*_*lly 5

维基百科提供了一个蹩脚的解释和一个不理想的算法。但让我们以它为起点。

首先让我们采用回溯算法。与其将矩阵的单元格“按某种顺序”放置,不如让我们将所有内容放在第一行中,然后是第二行中的所有内容,然后是第三行中的所有内容,依此类推。显然这会奏效。

现在让我们稍微修改回溯算法。我们将一行一行地进行,而不是逐个单元格地进行。所以我们列出n choose n/2可能的行,其中一半是 0,一半是 1。然后有一个递归函数,看起来像这样:

def count_0_1_matrices(n, filled_rows=None):
    if filled_rows is None:
        filled_rows = []
    if some_column_exceeds_threshold(n, filled_rows):
        # Cannot have more than n/2 0s or 1s in any column
        return 0
    else:
        answer = 0
        for row in possible_rows(n):
            answer = answer + count_0_1_matrices(n, filled_rows + [row])
        return answer
Run Code Online (Sandbox Code Playgroud)

这是一种回溯算法,就像我们之前的算法一样。我们一次只处理整行,而不是单元格。

但是请注意,我们传递的信息比我们需要的要多。无需传入行的精确排列。我们需要知道的只是在剩余的每一列中需要多少个 1。所以我们可以让算法看起来更像这样:

def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [int(n/2) for _ in range(n)]

    # Did we overrun any column?
    for i in still_needed:
        if i < 0:
            return 0

    # Did we reach the end of our matrix?
    if 0 == sum(still_needed):
        return 1

    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        answer = answer + count_0_1_matrices(n, next_still_needed)

    return answer
Run Code Online (Sandbox Code Playgroud)

这个版本几乎就是维基百科版本中的递归函数。主要区别在于我们的基本情况是,在每一行完成后,我们什么都不需要,而维基百科会让我们对基本情况进行编码,以在每行完成后检查最后一行。

要从这个到自上而下的 DP,你只需要记住这个函数。在 Python 中,您可以通过定义然后添加@memoize装饰器来完成。像这样:

from functools import wraps

def memoize(func):
    cache = {}
    @wraps(func)
    def wrap(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrap
Run Code Online (Sandbox Code Playgroud)

但还记得我批评过维基百科算法吗?让我们开始改进它!第一个大的改进是这个。你是否注意到元素的顺序still_needed无关紧要,只是它们的值?因此,仅对元素进行排序将阻止您为每个排列单独进行计算。(可以有很多排列!)

@memoize
def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [int(n/2) for _ in range(n)]

    # Did we overrun any column?
    for i in still_needed:
        if i < 0:
            return 0

    # Did we reach the end of our matrix?
    if 0 == sum(still_needed):
        return 1

    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        answer = answer + count_0_1_matrices(n, sorted(next_still_needed))

    return answer
Run Code Online (Sandbox Code Playgroud)

那个小无伤大雅sorted看起来不重要,但是省了很多功夫!现在我们知道它still_needed总是排序的,我们可以简化对是否完成以及是否有任何负面影响的检查。另外,我们可以添加一个简单的检查来过滤掉列中有太多 0 的情况。

@memoize
def count_0_1_matrices(n, still_needed=None):
    if still_needed is None:
        still_needed = [int(n/2) for _ in range(n)]

    # Did we overrun any column?
    if still_needed[-1] < 0:
        return 0

    total = sum(still_needed)
    if 0 == total:
        # We reached the end of our matrix.
        return 1
    elif total*2/n < still_needed[0]:
        # We have total*2/n rows left, but won't get enough 1s for a
        # column.
        return 0

    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        answer = answer + count_0_1_matrices(n, sorted(next_still_needed))

    return answer
Run Code Online (Sandbox Code Playgroud)

而且,假设您实施了possible_rows,这应该既有效又比维基百科提供的有效得多。

======

这是一个完整的工作实现。在我的机器上,它在 4 秒内计算了第 6 项。

#! /usr/bin/env python

from sys import argv
from functools import wraps

def memoize(func):
    cache = {}
    @wraps(func)
    def wrap(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrap

@memoize
def count_0_1_matrices(n, still_needed=None):
    if 0 == n:
        return 1

    if still_needed is None:
        still_needed = [int(n/2) for _ in range(n)]

    # Did we overrun any column?
    if still_needed[0] < 0:
        return 0

    total = sum(still_needed)
    if 0 == total:
        # We reached the end of our matrix.
        return 1
    elif total*2/n < still_needed[-1]:
        # We have total*2/n rows left, but won't get enough 1s for a
        # column.
        return 0
    # Calculate the answer by recursion.
    answer = 0
    for row in possible_rows(n):
        next_still_needed = [still_needed[i] - row[i] for i in range(n)]
        answer = answer + count_0_1_matrices(n, tuple(sorted(next_still_needed)))

    return answer

@memoize
def possible_rows(n):
    return [row for row in _possible_rows(n, n/2)]


def _possible_rows(n, k):
    if 0 == n:
        yield tuple()
    else:
        if k < n:
            for row in _possible_rows(n-1, k):
                yield tuple(row + (0,))
        if 0 < k:
            for row in _possible_rows(n-1, k-1):
                yield tuple(row + (1,))

n = 2
if 1 < len(argv):
    n = int(argv[1])

print(count_0_1_matrices(2*n)))
Run Code Online (Sandbox Code Playgroud)


גלע*_*רקן 2

您正在记住可能会重复的状态。在这种情况下需要记住的状态是向量(k是隐式的)。让我们看一下您链接到的示例之一。向量参数(长度n)中的每一对代表“尚未放置在该列中的零和一的数量”。

以左侧为例,其中向量为((1, 1) (1, 1) (1, 1) (1, 1)), when k = 2,导致它的赋值为1 0 1 0, k = 30 1 0 1, k = 4。但我们可以通过一组不同的分配达到相同的状态((1, 1) (1, 1) (1, 1) (1, 1)), k = 2,例如:0 1 0 1, k = 31 0 1 0, k = 4。如果我们记住状态 的结果((1, 1) (1, 1) (1, 1) (1, 1)),我们就可以避免再次重新计算该分支的递归。

如果有什么我可以更好地澄清的地方,请告诉我。

针对您的评论的进一步阐述:

维基百科的例子似乎是一个带有记忆的蛮力。该算法似乎试图枚举所有矩阵,但使用记忆来提前退出重复状态。我们如何枚举所有可能性?拿他们的例子来说,我们从尚未放置零和一的n = 4向量开始。[(2,2),(2,2),(2,2),(2,2)](由于向量中每个元组的总和为k,我们可以有一个更简单的向量,其中k和 1 或 0 的计数被保留。)

在递归的每个阶段k,我们都会枚举下一个向量的所有可能配置。如果状态存在于我们的哈希中,我们只需返回该键的值。否则,我们将向量分配为哈希中的新键(在这种情况下,此递归分支将继续)。

例如:

Vector                       [(2,2),(2,2),(2,2),(2,2)]

Possible assignments of 1's: [1 1 0 0], [1 0 1 0], [1 0 0 1] ... etc.

First branch:                [(2,1),(2,1),(1,2),(1,2)]
  is this vector a key in the hash?
  if yes, return value lookup
  else, assign this vector as a key in the hash where the value is the sum 
     of the function calls with the next possible vectors as their arguments
Run Code Online (Sandbox Code Playgroud)