有什么快速方法可以计算列表中 XOR b 大于 a AND b 的对的数量?

lea*_*ner 1 python algorithm time-complexity

我有一个数字数组,我想计算所有可能的对的组合,其中该对的异或运算大于与运算。

例子:

4,3,5,2
Run Code Online (Sandbox Code Playgroud)

可能的对是:

(4,3) -> xor=7, and = 0
(4,5) -> xor=1, and = 4
(4,2) -> xor=6, and = 0
(3,5) -> xor=6, and = 1
(3,2) -> xor=1, and = 2
(5,2) -> xor=7, and = 0

Valid pairs for which xor > and are (4,3), (4,2), (3,5), (5,2) so result is 4.
Run Code Online (Sandbox Code Playgroud)

这是我的程序:

def solve(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans +=1
    return ans
Run Code Online (Sandbox Code Playgroud)

时间复杂度为 O(n^2) ,但我的数组大小为 1 到 10^5 ,数组中的每个元素为 1 到 2^30 。那么如何降低这个程序的时间复杂度呢?

Pra*_*adi 5

这(有效地)使用与您相同的算法,因此它仍然是 O(n^2),但您可以使用 numpy 加速操作:

  • np.bitwise_xor对两个数组执行按位异或运算
  • np.bitwise_and对两个数组执行按位与运算
  • 为这些函数提供行向量和列向量允许 numpy 将结果广播到方阵。
  • 比较结果矩阵,我们得到一个布尔数组。我们只需要这个矩阵的下三角形。既然我们知道a ^ a == 0,我们可以简单地对整个数组求和并将其结果除以 2 即可得到答案。
import numpy as np

def npy(nums):
    xor_arr = np.bitwise_xor(nums, nums[:, None])
    and_arr = np.bitwise_and(nums, nums[:, None])

    return (xor_arr > and_arr).sum() // 2
Run Code Online (Sandbox Code Playgroud)

您还可以完全跳过 numpy 并numba在运行之前使用 JIT 编译您自己的代码。

import numba

@numba.njit
def nba(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans +=1
    return ans
Run Code Online (Sandbox Code Playgroud)

最后,这是我对戴夫算法的实现:

from collections import defaultdict
def new_alg(array):
    msb_num_count = defaultdict(int)
    for num in array:
        msb = len(bin(num)) - 2 # This was faster than right-shifting until zero
        msb_num_count[msb] += 1 # Increment the count of numbers that have this MSB
    
    # Now, for each number, the count will be the sum of the numbers in all other groups
    cnt = 0
    len_all_groups = len(array)
    for group_len in msb_num_count.values():
        cnt += group_len * (len_all_groups - group_len)

    return cnt // 2
Run Code Online (Sandbox Code Playgroud)

并且,作为 numba 兼容的函数。我需要定义一个get_msb因为numba.njit不会处理内置的 python 函数

@numba.njit
def get_msb(num):
    msb = 0
    while num:
        msb += 1
        num = num >> 1
    return msb

@numba.njit
def new_alg_numba(array):
    msb_num_count = {}
    for num in array:
        msb = get_msb(num)
        if msb not in msb_num_count:
            msb_num_count[msb] = 0
        msb_num_count[msb] += 1

    # Now, for each number, the count will be the sum of the numbers in all other groups
    cnt = 0
    len_all_groups = len(array)
        
    for grp_len in msb_num_count.values():
        cnt += grp_len * (len_all_groups - grp_len)

    return cnt // 2
Run Code Online (Sandbox Code Playgroud)

比较运行时,我们发现 numba 方法明显快于 numpy 方法,后者本身比 python 中的循环更快。

Dave 给出的线性时间算法一开始就比 numpy 方法更快,并且对于输入 > ~1000 个元素,它开始比 numba 编译的代码更快。这种方法的 numba 编译版本甚至更快 —— 它超过了 numba 编译版本,速度loopy约为 100 个元素。

Kelly 对 Dave 算法的出色实现与我针对较大输入的实现的 numba 版本相当

在此输入图像描述

(您的实现被标记为“loopy”。图中的其他图例标签与我上面的答案中的函数名称相同。Kelly 的实现被标记为“kelly”)