如何在Python中高效地获取总和为10或以下的所有组合

Pas*_*ten 5 python python-itertools

想象一下,您正尝试n=10在一定数量的区域(例如 )上分配一些固定资源(例如t=5)。我正在尝试有效地找出如何获得总和等于n或低于的所有组合。

例如10,0,0,0,0是好的,等等0,0,5,5,0,而3,3,3,3,3,3显然是错误的。

我到目前为止:

import itertools
t = 5
n = 10
r = [range(n+1)] * t
for x in itertools.product(*r): 
   if sum(x) <= n:          
       print x
Run Code Online (Sandbox Code Playgroud)

然而,这种蛮力方法的速度非常慢;一定会有更好的办法?

计时(1000 次迭代):

Default (itertools.product)           --- time: 40.90 s
falsetru recursion                    --- time:  3.63 s
Aaron Williams Algorithm (impl, Tony) --- time:  0.37 s
Run Code Online (Sandbox Code Playgroud)

Ton*_*ony 3

可能的方法如下。绝对会谨慎使用(几乎没有经过测试,但 n=10 和 t=5 的结果看起来合理)。

该方法涉及递归。生成具有 m 个元素(示例中为 5)的 n(示例中为 10)的分区的算法来自 Knuth 的第四卷。然后,如果需要,每个分区都会进行零扩展,并且所有不同的排列都是使用 Aaron Williams 的算法生成的,我已经在其他地方看到过该算法。两种算法都必须转换为 Python,这增加了错误出现的机会。Williams 算法需要一个链表,我必须用一个 2D 数组来伪造它,以避免编写链表类。

一个下午就过去了!

代码(注意你的n是我的maxn,你的t是我的p):

import itertools

def visit(a, m):
    """ Utility function to add partition to the list"""
    x.append(a[1:m+1])

def parts(a, n, m):
    """ Knuth Algorithm H, Combinatorial Algorithms, Pre-Fascicle 3B
        Finds all partitions of n having exactly m elements.
        An upper bound on running time is (3 x number of
        partitions found) + m.  Not recursive!      
    """
    while (1):
        visit(a, m)
        while a[2] < a[1]-1:
            a[1] -= 1
            a[2] += 1
            visit(a, m)
        j=3
        s = a[1]+a[2]-1
        while a[j] >= a[1]-1:
            s += a[j]
            j += 1
        if j > m:
            break
        x = a[j] + 1
        a[j] = x
        j -= 1
        while j>1:
            a[j] = x
            s -= x
            j -= 1
            a[1] = s

def distinct_perms(partition):
    """ Aaron Williams Algorithm 1, "Loopless Generation of Multiset
        Permutations by Prefix Shifts".  Finds all distinct permutations
        of a list with repeated items.  I don't follow the paper all that
        well, but it _possibly_ has a running time which is proportional
        to the number of permutations (with 3 shift operations for each  
        permutation on average).  Not recursive!
    """

    perms = []
    val = 0
    nxt = 1
    l1 = [[partition[i],i+1] for i in range(len(partition))]
    l1[-1][nxt] = None
    #print(l1)
    head = 0
    i = len(l1)-2
    afteri = i+1
    tmp = []
    tmp += [l1[head][val]]
    c = head
    while l1[c][nxt] != None:
        tmp += [l1[l1[c][nxt]][val]]
        c = l1[c][nxt]
    perms.extend([tmp])
    while (l1[afteri][nxt] != None) or (l1[afteri][val] < l1[head][val]):
        if (l1[afteri][nxt] != None) and (l1[i][val]>=l1[l1[afteri][nxt]][val]):
            beforek = afteri
        else:
            beforek = i
        k = l1[beforek][nxt]
        l1[beforek][nxt] = l1[k][nxt]
        l1[k][nxt] = head
        if l1[k][val] < l1[head][val]:
            i = k
        afteri = l1[i][nxt]
        head = k
        tmp = []
        tmp += [l1[head][val]]
        c = head
        while l1[c][nxt] != None:
            tmp += [l1[l1[c][nxt]][val]]
            c = l1[c][nxt]
        perms.extend([tmp])

    return perms

maxn = 10 # max integer to find partitions of
p = 5  # max number of items in each partition

# Find all partitions of length p or less adding up
# to maxn or less

# Special cases (Knuth's algorithm requires n and m >= 2)
x = [[i] for i in range(maxn+1)]
# Main cases: runs parts fn (maxn^2+maxn)/2 times
for i in range(2, maxn+1):
    for j in range(2, min(p+1, i+1)):
        m = j
        n = i
        a = [0, n-m+1] + [1] * (m-1) + [-1] + [0] * (n-m-1)
        parts(a, n, m)
y = []
# For each partition, add zeros if necessary and then find
# distinct permutations.  Runs distinct_perms function once
# for each partition.
for part in x:
    if len(part) < p:
        y += distinct_perms(part + [0] * (p - len(part)))
    else:
        y += distinct_perms(part)
print(y)
print(len(y))
Run Code Online (Sandbox Code Playgroud)