Any way to speed up itertools.product?

abh*_*sok 6 python numpy nested-loops python-itertools

I am using itertools.product to find the possible weights for a set of assets, given that all the weights must add up to 100.

import itertools
import numpy as np

min_wt = 10
max_wt = 50
step = 10
nb_Assets = 5

weight_mat = []
for i in itertools.product(range(min_wt, (max_wt+1), step), repeat = nb_Assets):
    if sum(i) == 100:
        weight = [i]
        if np.shape(weight_mat)[0] == 0:
            weight_mat = weight
        else:
            weight_mat = np.concatenate((weight_mat, weight), axis = 0)

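The code above works, but it is too slow because it goes through combinations that cannot be accepted. For example, it also tests [50,50,50,50,50], and in the end it checks 3125 combinations instead of the 121 valid ones. Is there any way to put the "sum" condition inside the loop to speed things up?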
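The counts in the question can be checked directly: with the weights 10, 20, 30, 40, 50 and 5 assets, `itertools.product` yields 5**5 = 3125 tuples, and only a small fraction of them sum to 100. A minimal counting sketch using the question's parameters:

```python
import itertools

min_wt, max_wt, step, nb_assets = 10, 50, 10, 5
grid = range(min_wt, max_wt + 1, step)  # 10, 20, 30, 40, 50

# Every tuple that itertools.product generates, valid or not
total = sum(1 for _ in itertools.product(grid, repeat=nb_assets))
# Only the tuples whose weights add up to 100
valid = sum(1 for t in itertools.product(grid, repeat=nb_assets) if sum(t) == 100)

print(total, valid)  # 3125 121
```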

Ral*_*alf 2

Comparing the performance of the proposed solutions:

import itertools
import timeit
import numpy as np


# original code from question
def f1():
    min_wt = 10
    max_wt = 50
    step = 10
    nb_assets = 5

    weight_mat = []
    for i in itertools.product(range(min_wt, (max_wt+1), step), repeat=nb_assets):
        if sum(i) == 100:
            weight = [i, ]
            if np.shape(weight_mat)[0] == 0:
                weight_mat = weight
            else:
                weight_mat = np.concatenate((weight_mat, weight), axis=0)

    return weight_mat


# code from question using list instead of numpy array
def f1b():
    min_wt = 10
    max_wt = 50
    step = 10
    nb_assets = 5

    weight_list = []
    for i in itertools.product(range(min_wt, (max_wt+1), step), repeat=nb_assets):
        if sum(i) == 100:
            weight_list.append(i)

    return weight_list


# calculating the last element of each tuple
def f2():
    min_wt = 10
    max_wt = 50
    step = 10
    nb_assets = 5

    weight_list = []
    for i in itertools.product(range(min_wt, (max_wt+1), step), repeat=nb_assets-1):
        the_sum = sum(i)
        if the_sum < 100:
            last_elem = 100 - the_sum
            if min_wt <= last_elem <= max_wt:
                weight_list.append(i + (last_elem, ))

    return weight_list


# recursive solution from user kaya3 (/sf/answers/4117669041/)
def constrained_partitions(n, k, min_w, max_w, w_step=1):
    if k < 0:
        raise ValueError('Number of parts must be at least 0')
    elif k == 0:
        if n == 0:
            yield ()
    else:
        for w in range(min_w, max_w+1, w_step):
            for p in constrained_partitions(n-w, k-1, min_w, max_w, w_step):
                yield (w,) + p

def f3():
    return list(constrained_partitions(100, 5, 10, 50, 10))


# recursive solution from user jdehesa (/sf/answers/4117679331/)
def make_weight_combs(min_wt, max_wt, step, nb_assets, req_wt):
    weights = range(min_wt, max_wt + 1, step)
    current = []
    yield from _make_weight_combs_rec(weights, nb_assets, req_wt, current)

def _make_weight_combs_rec(weights, nb_assets, req_wt, current):
    if nb_assets <= 0:
        yield tuple(current)
    else:
        # Discard weights that cannot possibly be used
        while weights and weights[0] + weights[-1] * (nb_assets - 1) < req_wt:
            weights = weights[1:]
        while weights and weights[-1] + weights[0] * (nb_assets - 1) > req_wt:
            weights = weights[:-1]
        # Add all possible weights
        for w in weights:
            current.append(w)
            yield from _make_weight_combs_rec(weights, nb_assets - 1, req_wt - w, current)
            current.pop()

def f4():
    return list(make_weight_combs(10, 50, 10, 5, 100))
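The `f2` idea can be made a little more general. `f2` only range-checks `last_elem`, which is enough for the tested parameters because every forced last weight lands on the grid. With other parameters (for example `step = 20`, giving the grid 10, 30, 50) the forced weight can fall between grid values and would be wrongly accepted. A sketch, assuming the target sum is passed in as a parameter (`weights_last_elem` and `total` are hypothetical names), that adds the grid check and builds the NumPy array once at the end, which is what the question's `weight_mat` was after:

```python
import itertools
import numpy as np


def weights_last_elem(min_wt, max_wt, step, nb_assets, total=100):
    """Enumerate weight tuples summing to `total` (f2-style sketch).

    Only the first nb_assets - 1 weights are enumerated; the last one is
    forced to total - sum(prefix) and kept only if it lies on the grid.
    """
    grid = range(min_wt, max_wt + 1, step)
    rows = []
    for prefix in itertools.product(grid, repeat=nb_assets - 1):
        last = total - sum(prefix)
        # The last weight must be inside [min_wt, max_wt] AND be a grid value
        if min_wt <= last <= max_wt and (last - min_wt) % step == 0:
            rows.append(prefix + (last,))
    # Build the NumPy array once instead of calling np.concatenate per row
    return np.array(rows, dtype=int).reshape(-1, nb_assets)
```

For the question's parameters, `weights_last_elem(10, 50, 10, 5).shape` is `(121, 5)`. Appending to a list and converting once avoids the repeated copying that `np.concatenate` inside the loop causes, which is also why `f1b` beats `f1`.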

I timed the functions with timeit as follows:

print(timeit.timeit('f()', 'from __main__ import f1 as f', number=100))
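The `timeit` call above makes each function visible to the timer by importing it from `__main__` in the setup string. A minimal alternative sketch, assuming Python 3.5+ (where `timeit.timeit` and `timeit.repeat` accept a `globals` argument), passes the function in directly and reports the minimum of several repeats, which the `timeit` documentation suggests as the least noisy figure. The two functions are trimmed-down stand-ins for `f1b` and `f2`, and `best_time` is a hypothetical helper name:

```python
import itertools
import timeit


def f1b():
    # Filter the full product (same idea as f1b above)
    grid = range(10, 51, 10)
    return [t for t in itertools.product(grid, repeat=5) if sum(t) == 100]


def f2():
    # Enumerate 4 weights and derive the 5th (same idea as f2 above)
    grid = range(10, 51, 10)
    out = []
    for p in itertools.product(grid, repeat=4):
        last = 100 - sum(p)
        if 10 <= last <= 50:
            out.append(p + (last,))
    return out


def best_time(func, number=100, repeat=5):
    """Return the fastest of `repeat` runs, each timing `number` calls."""
    # globals= lets timeit see func without a 'from __main__ import ...' setup string
    return min(timeit.repeat("func()", globals={"func": func},
                             number=number, repeat=repeat))


if __name__ == "__main__":
    for f in (f1b, f2):
        print(f.__name__, best_time(f))
```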

Results with the parameters from the question:

# min_wt = 10
# max_wt = 50
# step = 10
# nb_assets = 5
0.07021828400320373       # f1 - original code from question
0.041302188008558005      # f1b - code from question using list instead of numpy array
0.009902548001264222      # f2 - calculating the last element of each tuple
0.10601829699589871       # f3 - recursive solution from user kaya3
0.03329997700348031       # f4 - recursive solution from user jdehesa

If I enlarge the search space (a smaller step and more assets):

# min_wt = 10
# max_wt = 50
# step = 5
# nb_assets = 6
7.6620834979985375        # f1 - original code from question
7.31425816299452          # f1b - code from question using list instead of numpy array
0.809070186005556         # f2 - calculating the last element of each tuple
14.88188026699936         # f3 - recursive solution from user kaya3
0.39385621099791024       # f4 - recursive solution from user jdehesa
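A quick count of the search space is in line with the second table. With `step = 5` the grid has 9 values (10, 15, ..., 50), so the brute-force loop in `f1`/`f1b` visits 9**6 tuples while `f2` visits only 9**5. That is a factor of 9, close to the measured 7.31 s vs 0.81 s. A small sketch that computes these sizes and the number of valid rows (the helper names are hypothetical):

```python
import itertools


def candidates_examined(min_wt, max_wt, step, nb_assets):
    """Tuples visited by the brute-force loop (f1/f1b) and the last-element loop (f2)."""
    n_values = len(range(min_wt, max_wt + 1, step))
    return n_values ** nb_assets, n_values ** (nb_assets - 1)


def count_valid(min_wt, max_wt, step, nb_assets, total=100):
    """Number of weight tuples that actually sum to `total` (brute force)."""
    grid = range(min_wt, max_wt + 1, step)
    return sum(1 for t in itertools.product(grid, repeat=nb_assets) if sum(t) == total)


print(candidates_examined(10, 50, 10, 5))  # (3125, 625)
print(candidates_examined(10, 50, 5, 6))   # (531441, 59049)
print(count_valid(10, 50, 5, 6))           # 1287
```

So in the larger case the brute-force loop inspects 531441 tuples to keep 1287 of them, which is why pruning pays off more as the space grows.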

It looks like f2 and f4 are the fastest (for the data sizes tested).
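Both fast versions avoid generating tuples that can never reach the target: `f2` fixes the last weight, and `f4` trims the usable weight range at each level. The same pruning can also be written with explicit lower and upper bounds per level, so that no dead branch is ever entered. This is a minimal sketch of that idea, not code from any of the linked answers (`bounded_weights` and its `total`/`prefix` parameters are hypothetical names), and its speed is not part of the measurements above:

```python
import itertools


def bounded_weights(min_wt, max_wt, step, nb_assets, total=100, prefix=()):
    """Yield weight tuples summing to `total`, pruning infeasible branches.

    The next weight w must leave a remainder that the other (nb_assets - 1)
    assets can still cover with grid weights in [min_wt, max_wt].
    """
    if nb_assets == 0:
        if total == 0:
            yield prefix
        return
    rest = nb_assets - 1
    lo = max(min_wt, total - max_wt * rest)
    hi = min(max_wt, total - min_wt * rest)
    # Round lo up to the next grid value min_wt + k * step (ceiling division)
    if lo > min_wt:
        lo = min_wt + -(-(lo - min_wt) // step) * step
    for w in range(lo, hi + 1, step):
        yield from bounded_weights(min_wt, max_wt, step, rest, total - w, prefix + (w,))


print(len(list(bounded_weights(10, 50, 10, 5))))  # 121
```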