Efficiently computing all combinations of numbers whose sum is close to 0

TLa*_*nni 7 python performance numpy pandas

I have the following pandas DataFrame df:

column1 column2 list_numbers          sublist_column
x        y      [10,-6,1,-4]             
a        b      [1,3,7,-2]               
p        q      [6,2,-3,-3.2]             

sublist_column should contain the numbers from the list_numbers column whose sum is 0 (with 0.5 as a tolerance). I wrote the following code:

def return_list(original_lst, target_sum, tolerance):
    memo = dict()
    sublist = []
    for i, x in enumerate(original_lst):
        # Keep x if some subset of the remaining items can complete the sum.
        if memo_func(original_lst, i + 1, target_sum - x, memo, tolerance) > 0:
            sublist.append(x)
            target_sum -= x
    return sublist

def memo_func(original_lst, i, target_sum, memo, tolerance):
    # Count the subsets of original_lst[i:] whose sum is within
    # tolerance of target_sum.
    if i >= len(original_lst):
        if -tolerance <= target_sum <= tolerance:
            return 1
        else:
            return 0
    if (i, target_sum) not in memo:
        c = memo_func(original_lst, i + 1, target_sum, memo, tolerance)
        c += memo_func(original_lst, i + 1, target_sum - original_lst[i], memo, tolerance)
        memo[(i, target_sum)] = c
    return memo[(i, target_sum)]
    

I then use the return_list function on the list_numbers column to populate sublist_column:

target_sum = 0
tolerance = 0.5

df['sublist_column'] = df['list_numbers'].apply(lambda x: return_list(x, target_sum, tolerance))

The resulting DataFrame looks like this:

column1 column2 list_numbers          sublist_column
x        y      [10,-6,1,-4]             [10,-6,-4]
a        b      [1,3,7,-2]               []
p        q      [6,2,-3,-3.2]            [6,-3,-3.2]  # sum is -0.2 (within the tolerance)

This gives me the correct result, but it is extremely slow (it takes 2 hours to run in the Spyder IDE), since my DataFrame has about 50,000 rows and some of the lists in the list_numbers column contain more than 15 elements. The runtime suffers particularly once a list in list_numbers has more than 15 elements; for example, the following list takes almost 15 minutes to process:

[-1572.35,-76.16,-261.1,-7732.0,-1634.0,-52082.42,-3974.15,
-801.65,-30192.79,-671.98,-73.06,-47.72,57.96,-511.18,-391.87,-4145.0,-1008.61,
-17.53,-17.53,-1471.08,-119.26,-2269.7,-2709,-182939.59,-19.48,-516,-6875.75,-138770.16,-71.11,-295.84,-348.09,-3460.71,-704.01,-678,-632.15,-21478.76]

How can I significantly reduce the runtime?

Jér*_*ard 17

Step 1: Use Numba

Based on the comments, memo_func seems to be the main bottleneck. You can use Numba to speed up its execution. Numba compiles Python code to native code thanks to a just-in-time (JIT) compiler. The JIT is able to perform tail-call optimization, and native function calls are much faster than CPython's. Here is an example:

import numba as nb
import numpy as np

@nb.njit('(float64[::1], int64, float64, float64)')
def memo_func(original_arr, i, target_sum, tolerance):
    if i >= len(original_arr):
        if -tolerance <= target_sum <= tolerance:
            return 1
        return 0
    c = memo_func(original_arr, i + 1, target_sum, tolerance)
    c += memo_func(original_arr, i + 1, target_sum - original_arr[i], tolerance)
    return c

@nb.njit('(float64[::1], float64, float64)')
def return_list(original_arr, target_sum, tolerance):
    sublist = []
    for i, x in enumerate(original_arr):
        if memo_func(original_arr, np.int64(i + 1), target_sum - x, tolerance) > 0:
            sublist.append(x)
            target_sum -= x
    return sublist

Memoization does not appear to speed things up here (which is why the function above keeps the name memo_func but no longer uses a memo), and it is somewhat cumbersome to implement in Numba. In fact, there are far better ways to improve the algorithm.
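For reference, here is a rough sketch of what memoization could look like with Numba's typed containers (an illustration only, not code from this answer; memo_func_cached and key_type are made-up names):

from numba import njit, types
from numba.typed import Dict

# Key: (index, remaining target). Value: number of matching subsets.
key_type = types.Tuple((types.int64, types.float64))

@njit
def memo_func_cached(arr, i, target_sum, tolerance, memo):
    if i >= len(arr):
        return 1 if -tolerance <= target_sum <= tolerance else 0
    key = (i, target_sum)
    if key not in memo:
        c = memo_func_cached(arr, i + 1, target_sum, tolerance, memo)
        c += memo_func_cached(arr, i + 1, target_sum - arr[i], tolerance, memo)
        memo[key] = c
    return memo[key]

# Usage: memo_func_cached(np.array(lst, np.float64), 0, 0.0, 0.5, memo)
memo = Dict.empty(key_type, types.int64)

In practice the hit rate of such a cache is poor: the running target_sum is a float that rarely takes exactly the same value twice, so most keys are unique, which is consistent with the observation above.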

Note that you need to convert the lists to NumPy arrays before calling the function:

lst = [-850.85,-856.05,-734.09,5549.63,77.59,-39.73,23.63,13.93,-6455.54,-417.07,176.72,-570.41,3621.89,-233.47,-471.54,-30.33,-941.49,-1014.6,1614.5]
result = return_list(np.array(lst, np.float64), 0.0, tolerance)
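To process the whole DataFrame, the same conversion can be applied row by row (a usage sketch based on the question's apply call):

df['sublist_column'] = df['list_numbers'].apply(
    lambda lst: return_list(np.array(lst, np.float64), 0.0, tolerance))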

Step 2: Tail-call optimization

Calling a deep cascade of functions to process the tail of the input list is not efficient. The JIT is able to reduce the number of function calls, but it cannot remove them completely. Instead, you can unroll the last levels of the recursion, where most of the calls happen. For example, when there are 6 items left to process, you can use the following code:

# Inside memo_func: enumerate the 2**6 subsets of the 6 remaining items.
if n - i == 6:
    c = 0
    s0 = target_sum
    v0, v1, v2, v3, v4, v5 = original_arr[i:]
    for s1 in (s0, s0 - v0):
        for s2 in (s1, s1 - v1):
            for s3 in (s2, s2 - v2):
                for s4 in (s3, s3 - v3):
                    for s5 in (s4, s4 - v4):
                        for s6 in (s5, s5 - v5):
                            c += np.int64(-tolerance <= s6 <= tolerance)
    return c

This is quite ugly, but it is far more efficient because the JIT is able to unroll all the loops and generate very fast code. Still, this is not enough for large lists.


Step 3: A better algorithm

For large input lists, the problem is the exponential complexity of the algorithm. Indeed, this problem looks very much like a relaxed variant of subset-sum, which is known to be NP-complete. Such problems are notoriously hard to solve, and the best exact practical algorithms known so far for NP-complete problems are exponential. In short, this means that for any sufficiently large input, no known algorithm can find an exact solution in a reasonable amount of time (e.g., less than a human lifetime).
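To get a sense of the scale, a quick back-of-the-envelope check for the 36-element list above:

# An n-element list has 2**n subsets; the recursive search explores up to
# that many states, and return_list restarts it once per element.
n = 36
print(f"{2**n:,} subsets")  # 68,719,476,736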

That said, there are heuristics and strategies that can improve the complexity of the current algorithm. One efficient approach is the meet-in-the-middle algorithm. Applied to your use case, the idea is to generate a large set of partial target sums, sort it, and then use a binary search to count the matching values. This works here because the condition -tolerance <= target_sum <= tolerance, with target_sum = partial_sum1 - partial_sum2, is equivalent to partial_sum2 - tolerance <= partial_sum1 <= partial_sum2 + tolerance.
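As a minimal sketch of the idea in pure Python (illustrative only; has_subset_sum_mitm and the itertools-based enumeration are not part of this answer's code, which generates and prunes the sums far more efficiently):

import bisect
from itertools import combinations

def has_subset_sum_mitm(nums, target_sum, tolerance):
    # Split the list in two halves and enumerate each half's subset sums.
    half = len(nums) // 2
    left, right = nums[:half], nums[half:]
    right_sums = sorted(
        sum(c) for r in range(len(right) + 1) for c in combinations(right, r))
    # We need: -tol <= target - s_left - s_right <= tol, i.e.
    # target - s_left - tol <= s_right <= target - s_left + tol.
    for r in range(len(left) + 1):
        for c in combinations(left, r):
            s_left = sum(c)
            lo = bisect.bisect_left(right_sums, target_sum - s_left - tolerance)
            if lo < len(right_sums) and right_sums[lo] <= target_sum - s_left + tolerance:
                return True
    return False

This reduces the work from O(2**n) to roughly O(2**(n/2) * n), at the cost of O(2**(n/2)) memory. Like the base case of the recursive version, it counts the empty subset as a trivial match when |target_sum| <= tolerance.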

Unfortunately, the resulting code is rather big and not simple, but this is certainly the price to pay for solving a complex problem like this one efficiently. Here it is:

# Generate all the target sums based on in_arr and put the result in out_sum
@nb.njit('(float64[::1], float64[::1], float64)', cache=True)
def gen_all_comb(in_arr, out_sum, target_sum):
    assert in_arr.size >= 6
    if in_arr.size == 6:
        assert out_sum.size == 64
        v0, v1, v2, v3, v4, v5 = in_arr
        s0 = target_sum
        cur = 0
        for s1 in (s0, s0 - v0):
            for s2 in (s1, s1 - v1):
                for s3 in (s2, s2 - v2):
                    for s4 in (s3, s3 - v3):
                        for s5 in (s4, s4 - v4):
                            for s6 in (s5, s5 - v5):
                                out_sum[cur] = s6
                                cur += 1
    else:
        assert out_sum.size % 2 == 0
        mid = out_sum.size // 2
        gen_all_comb(in_arr[1:], out_sum[:mid], target_sum)
        gen_all_comb(in_arr[1:], out_sum[mid:], target_sum - in_arr[0])

# Find the number of items in sorted_arr where:
# lower_bound <= item <= upper_bound
@nb.njit('(float64[::1], float64, float64)', cache=True)
def count_between(sorted_arr, lower_bound, upper_bound):
    assert lower_bound <= upper_bound
    lo_pos = np.searchsorted(sorted_arr, lower_bound, side='left')
    hi_pos = np.searchsorted(sorted_arr, upper_bound, side='right')
    return hi_pos - lo_pos

# Count all the target sums in:
# -tolerance <= all_target_sums(in_arr,sorted_target_sums)-s0 <= tolerance
@nb.njit('(float64[::1], float64[::1], float64, float64)', cache=True)
def multi_search(in_arr, sorted_target_sums, tolerance, s0):
    assert in_arr.size >= 6
    if in_arr.size == 6:
        v0, v1, v2, v3, v4, v5 = in_arr
        c = 0
        for s1 in (s0, s0 + v0):
            for s2 in (s1, s1 + v1):
                for s3 in (s2, s2 + v2):
                    for s4 in (s3, s3 + v3):
                        for s5 in (s4, s4 + v4):
                            for s6 in (s5, s5 + v5):
                                lo = -tolerance + s6
                                hi = tolerance + s6
                                c += count_between(sorted_target_sums, lo, hi)
        return c
    else:
        c = multi_search(in_arr[1:], sorted_target_sums, tolerance, s0)
        c += multi_search(in_arr[1:], sorted_target_sums, tolerance, s0 + in_arr[0])
        return c

@nb.njit('(float64[::1], int64, float64, float64)', cache=True)
def memo_func(original_arr, i, target_sum, tolerance):
    n = original_arr.size
    remaining = n - i
    tail_size = min(max(remaining//2, 7), 16)

    # Tail call: for very small list (trivial case)
    if remaining <= 0:
        return np.int64(-tolerance <= target_sum <= tolerance)

    # Tail call: for big lists (better algorithm)
    elif remaining >= tail_size*2:
        partial_sums = np.empty(2**tail_size, dtype=np.float64)
        gen_all_comb(original_arr[-tail_size:], partial_sums, target_sum)
        partial_sums.sort()
        return multi_search(original_arr[-remaining:-tail_size], partial_sums, tolerance, 0.0)

    # Tail call: for medium-sized list (unrolling)
    elif remaining == 6:
        c = 0
        s0 = target_sum
        v0, v1, v2, v3, v4, v5 = original_arr[i:]
        for s1 in (s0, s0 - v0):
            for s2 in (s1, s1 - v1):
                for s3 in (s2, s2 - v2):
                    for s4 in (s3, s3 - v3):
                        for s5 in (s4, s4 - v4):
                            for s6 in (s5, s5 - v5):
                                c += np.int64(-tolerance <= s6 <= tolerance)
        return c

    # Recursion
    c = memo_func(original_arr, i + 1, target_sum, tolerance)
    c += memo_func(original_arr, i + 1, target_sum - original_arr[i], tolerance)
    return c

@nb.njit('(float64[::1], float64, float64)', cache=True)
def return_list(original_arr, target_sum, tolerance):
    sublist = []
    for i, x in enumerate(original_arr):
        if memo_func(original_arr, np.int64(i + 1), target_sum - x, tolerance) > 0:
            sublist.append(x)
            target_sum -= x
    return sublist

Note that since the code is quite big, it takes a few seconds to compile. The cache=True flag should help avoid recompiling it every time.


Step 4: An even better algorithm

The previous code counts the number of matching values (the value stored in c). This is not necessary, since we only want to know whether at least one value exists (i.e., memo_func(...) > 0). We can therefore return a boolean indicating whether a value was found, and optimize the algorithm so that it returns True as soon as an early solution is found. This makes it possible to skip large parts of the exploration tree (the method is especially effective when many solutions exist, e.g., for random arrays).

Another optimization is to perform only one binary search instead of two, and to check beforehand whether the searched value can even lie within the min-max range of the sorted array (so this trivial case is skipped before applying the expensive binary search). This is made possible by the previous optimization.

A final optimization is to prune parts of the exploration tree early when the values generated in multi_search are so small or so large that we can be certain no binary search is needed. This is done by computing a pessimistic over-approximation of the range of values to search. It is especially useful for pathological cases with few or no solutions.

Here is the final implementation:

@nb.njit('(float64[::1], float64[::1], float64)', cache=True)
def gen_all_comb(in_arr, out_sum, target_sum):
    assert in_arr.size >= 6
    if in_arr.size == 6:
        assert out_sum.size == 64
        v0, v1, v2, v3, v4, v5 = in_arr
        s0 = target_sum
        cur = 0
        for s1 in (s0, s0 - v0):
            for s2 in (s1, s1 - v1):
                for s3 in (s2, s2 - v2):
                    for s4 in (s3, s3 - v3):
                        for s5 in (s4, s4 - v4):
                            for s6 in (s5, s5 - v5):
                                out_sum[cur] = s6
                                cur += 1
    else:
        assert out_sum.size % 2 == 0
        mid = out_sum.size // 2
        gen_all_comb(in_arr[1:], out_sum[:mid], target_sum)
        gen_all_comb(in_arr[1:], out_sum[mid:], target_sum - in_arr[0])

# Check whether sorted_arr contains at least one item where:
# lower_bound <= item <= upper_bound
@nb.njit('(float64[::1], float64, float64)', cache=True)
def has_items_between(sorted_arr, lower_bound, upper_bound):
    if upper_bound < sorted_arr[0] or sorted_arr[sorted_arr.size-1] < lower_bound:
        return False
    lo_pos = np.searchsorted(sorted_arr, lower_bound, side='left')
    return lo_pos < sorted_arr.size and sorted_arr[lo_pos] <= upper_bound

# Check whether any target sum satisfies:
# -tolerance <= all_target_sums(in_arr,sorted_target_sums)-s0 <= tolerance
@nb.njit('(float64[::1], float64[::1], float64, float64)', cache=True)
def multi_search(in_arr, sorted_target_sums, tolerance, s0):
    assert in_arr.size >= 6
    if in_arr.size == 6:
        v0, v1, v2, v3, v4, v5 = in_arr
        x3, x4, x5 = min(v3, 0), min(v4, 0), min(v5, 0)
        y3, y4, y5 = max(v3, 0), max(v4, 0), max(v5, 0)
        mini = sorted_target_sums[0]
        maxi = sorted_target_sums[sorted_target_sums.size-1]

        for s1 in (s0, s0 + v0):
            for s2 in (s1, s1 + v1):
                for s3 in (s2, s2 + v2):
                    # Prune the exploration tree early if a 
                    # larger range cannot be found.
                    lo = s3 + (x3 + x4 + x5 - tolerance)
                    hi = s3 + (y3 + y4 + y5 + tolerance)
                    if hi < mini or maxi < lo:
                        continue

                    for s4 in (s3, s3 + v3):
                        for s5 in (s4, s4 + v4):
                            for s6 in (s5, s5 + v5):
                                lo = -tolerance + s6
                                hi = tolerance + s6
                                if has_items_between(sorted_target_sums, lo, hi):
                                    return True
        return False
    return (
        multi_search(in_arr[1:], sorted_target_sums, tolerance, s0)
        or multi_search(in_arr[1:], sorted_target_sums, tolerance, s0 + in_arr[0])
    )

@nb.njit('(float64[::1], int64, float64, float64)', cache=True)
def memo_func(original_arr, i, target_sum, tolerance):
    n = original_arr.size
    remaining = n - i
    tail_size = min(max(remaining//2, 7), 13)

    # Tail call: for very small list (trivial case)
    if remaining <= 0:
        return -tolerance <= target_sum <= tolerance

    # Tail call: for big lists (better algorithm)
    elif remaining >= tail_size*2:
        partial_sums = np.empty(2**tail_size, dtype=np.float64)
        gen_all_comb(original_arr[-tail_size:], partial_sums, target_sum)
        partial_sums.sort()
        return multi_search(original_arr[-remaining:-tail_size], partial_sums, tolerance, 0.0)

    # Tail call: for medium-sized list (unrolling)
    elif remaining == 6:
        s0 = target_sum
        v0, v1, v2, v3, v4, v5 = original_arr[i:]
        for s1 in (s0, s0 - v0):
            for s2 in (s1, s1 - v1):
                for s3 in (s2, s2 - v2):
                    for s4 in (s3, s3 - v3):
                        for s5 in (s4, s4 - v4):
                            for s6 in (s5, s5 - v5):
                                if -tolerance <= s6 <= tolerance:
                                    return True
        return False

    # Recursion
    return (
        memo_func(original_arr, i + 1, target_sum, tolerance)
        or memo_func(original_arr, i + 1, target_sum - original_arr[i], tolerance)
    )

@nb.njit('(float64[::1], float64, float64)', cache=True)
def return_list(original_arr, target_sum, tolerance):
    sublist = []
    for i, x in enumerate(original_arr):
        if memo_func(original_arr, np.int64(i + 1), target_sum - x, tolerance):
            sublist.append(x)
            target_sum -= x
    return sublist

The final implementation is tuned to handle pathological cases efficiently (few non-trivial solutions, or even none at all, as in the big input list provided). It can however be tuned to compute cases with many solutions faster (e.g., large arrays of uniformly distributed random numbers), at the cost of a significantly slower execution on pathological cases. This trade-off can be adjusted through the tail_size variable (smaller values are better for inputs with more solutions).
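For instance, a variant biased toward solution-rich inputs could simply lower the cap in memo_func (the value 10 below is purely illustrative, not a tuning from this answer):

# Hypothetical setting favoring inputs with many solutions:
tail_size = min(max(remaining // 2, 7), 10)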


Benchmark

Here are the tested inputs:

target_sum = 0
tolerance = 0.5

small_lst = [-850.85,-856.05,-734.09,5549.63,77.59,-39.73,23.63,13.93,-6455.54,-417.07,176.72,-570.41,3621.89,-233.47,-471.54,-30.33,-941.49,-1014.6,1614.5]
big_lst = [-1572.35,-76.16,-261.1,-7732.0,-1634.0,-52082.42,-3974.15,-801.65,-30192.79,-671.98,-73.06,-47.72,57.96,-511.18,-391.87,-4145.0,-1008.61,-17.53,-17.53,-1471.08,-119.26,-2269.7,-2709,-182939.59,-19.48,-516,-6875.75,-138770.16,-71.11,-295.84,-348.09,-3460.71,-704.01,-678,-632.15,-21478.76]
random_lst = [-86145.13, -34783.33, 50912.99, -87823.73, 37537.52, -22796.4, 53530.74, 65477.91, -50725.36, -52609.35, 92769.95, 83630.42, 30436.95, -24347.08, -58197.95, 77504.44, 83958.08, -85095.73, -61347.26, -14250.65, 2012.91, 83969.32, -69356.41, 29659.23, 94736.29, 2237.82, -17784.34, 23079.36, 8059.84, 26751.26, 98427.46, -88735.07, -28936.62, 21868.77, 5713.05, -74346.18]
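For reference, a minimal timing harness along the following lines can reproduce such measurements (an assumption; the benchmarking code is not part of this answer):

import timeit
import numpy as np

arr = np.array(big_lst, np.float64)
runs = 10
elapsed = timeit.timeit(lambda: return_list(arr, 0.0, tolerance), number=runs)
print(f"return_list: {elapsed / runs * 1e3:.3f} ms per call")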

The uniformly distributed random list has a huge number of solutions, while the provided big list has none. The tuned final implementation sets tail_size to min(max(remaining//2, 7), 13) so that the random list is computed faster, at the cost of a significantly slower execution on the big list.

Here are the timings for the small list on my machine:

Naive python algorithm:               173.45 ms
Naive algorithm using Numba:            7.21 ms
Tail call optimization + Numba:         0.33 ms
KellyBundy's implementation:            0.19 ms
Efficient algorithm + optim + Numba:    0.10 ms
Final implementation (tuned):           0.05 ms
Final implementation (default):         0.05 ms

Here are the timings for the large random list on my machine (the easy case):

Efficient algorithm + optim + Numba:    209.61 ms
Final implementation (default):           4.11 ms
KellyBundy's implementation:              1.15 ms
Final implementation (tuned):             0.85 ms

Other algorithms are not shown here because they are too slow (see below)

Here are the timings for the big list on my machine (the challenging case):

Naive python algorithm:               >20000 s    [estimation & out of memory]
Naive algorithm using Numba:            ~900 s    [estimation]
Tail call optimization + Numba:           42.61 s
KellyBundy's implementation:               0.671 s
Final implementation (tuned):              0.078 s
Efficient algorithm + optim + Numba:       0.051 s
Final implementation (default):            0.013 s

The final implementation is thus about 3,500 times faster on the small input, and more than 1,500,000 times faster on the large one! It also uses much less RAM, so it can actually be executed on a cheap PC.

Note that the execution time can be reduced even further using multiple threads, reaching speedups above 5,000,000, although this may be slower on small inputs and would make the code a bit more complex.
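One possible way to parallelize over the DataFrame rows is sketched below (an assumption, not part of this answer: the @nb.njit decorators would additionally need nogil=True so that the threads can actually run concurrently instead of being serialized by the GIL):

from concurrent.futures import ThreadPoolExecutor
import numpy as np

# Convert each row's list once, then dispatch the rows to a thread pool.
arrays = [np.array(lst, np.float64) for lst in df['list_numbers']]
with ThreadPoolExecutor() as pool:
    df['sublist_column'] = list(pool.map(
        lambda a: return_list(a, 0.0, tolerance), arrays))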


• What a great answer! (7 upvotes)