lea*_*ner 1 python algorithm time-complexity
我有一个数字数组,我想计算所有可能的对的组合,其中该对的异或运算大于与运算。
例子:
4,3,5,2
Run Code Online (Sandbox Code Playgroud)
可能的对是:
(4,3) -> xor=7, and = 0
(4,5) -> xor=1, and = 4
(4,2) -> xor=6, and = 0
(3,5) -> xor=6, and = 1
(3,2) -> xor=1, and = 2
(5,2) -> xor=7, and = 0
Valid pairs for which xor > and are (4,3), (4,2), (3,5), (5,2) so result is 4.
Run Code Online (Sandbox Code Playgroud)
这是我的程序:
def solve(array):
n = len(array)
ans = 0
for i in range(0, n):
p1 = array[i]
for j in range(i, n):
p2 = array[j]
if p1 ^ p2 > p1 & p2:
ans +=1
return ans
Run Code Online (Sandbox Code Playgroud)
时间复杂度为 O(n^2) ,但我的数组大小为 1 到 10^5 ,数组中的每个元素为 1 到 2^30 。那么如何降低这个程序的时间复杂度呢?
这(有效地)使用与您相同的算法,因此它仍然是 O(n^2),但您可以使用 numpy 加速操作:
np.bitwise_xor对两个数组执行按位异或运算np.bitwise_and对两个数组执行按位与运算a ^ a == 0,我们可以简单地对整个数组求和并将其结果除以 2 即可得到答案。import numpy as np
def npy(nums):
xor_arr = np.bitwise_xor(nums, nums[:, None])
and_arr = np.bitwise_and(nums, nums[:, None])
return (xor_arr > and_arr).sum() // 2
Run Code Online (Sandbox Code Playgroud)
您还可以完全跳过 numpy 并numba在运行之前使用 JIT 编译您自己的代码。
import numba
@numba.njit
def nba(array):
n = len(array)
ans = 0
for i in range(0, n):
p1 = array[i]
for j in range(i, n):
p2 = array[j]
if p1 ^ p2 > p1 & p2:
ans +=1
return ans
Run Code Online (Sandbox Code Playgroud)
最后,这是我对戴夫算法的实现:
from collections import defaultdict
def new_alg(array):
msb_num_count = defaultdict(int)
for num in array:
msb = len(bin(num)) - 2 # This was faster than right-shifting until zero
msb_num_count[msb] += 1 # Increment the count of numbers that have this MSB
# Now, for each number, the count will be the sum of the numbers in all other groups
cnt = 0
len_all_groups = len(array)
for group_len in msb_num_count.values():
cnt += group_len * (len_all_groups - group_len)
return cnt // 2
Run Code Online (Sandbox Code Playgroud)
并且,作为 numba 兼容的函数。我需要定义一个get_msb因为numba.njit不会处理内置的 python 函数
@numba.njit
def get_msb(num):
msb = 0
while num:
msb += 1
num = num >> 1
return msb
@numba.njit
def new_alg_numba(array):
msb_num_count = {}
for num in array:
msb = get_msb(num)
if msb not in msb_num_count:
msb_num_count[msb] = 0
msb_num_count[msb] += 1
# Now, for each number, the count will be the sum of the numbers in all other groups
cnt = 0
len_all_groups = len(array)
for grp_len in msb_num_count.values():
cnt += grp_len * (len_all_groups - grp_len)
return cnt // 2
Run Code Online (Sandbox Code Playgroud)
比较运行时,我们发现 numba 方法明显快于 numpy 方法,后者本身比 python 中的循环更快。
Dave 给出的线性时间算法一开始就比 numpy 方法更快,并且对于输入 > ~1000 个元素,它开始比 numba 编译的代码更快。这种方法的 numba 编译版本甚至更快 —— 它超过了 numba 编译版本,速度loopy约为 100 个元素。
Kelly 对 Dave 算法的出色实现与我针对较大输入的实现的 numba 版本相当

(您的实现被标记为“loopy”。图中的其他图例标签与我上面的答案中的函数名称相同。Kelly 的实现被标记为“kelly”)