tut*_*eri 2 optimization compilation python-3.x numba sorting-network
这段代码实现了排序网络,我正在尝试使用 Numba 来编译它们以提高性能。然而,每个函数的编译时间呈指数增长。总共有大约 60 个函数(下面仅显示 19 个示例),Numba 在太阳经历红巨星膨胀之前无法完成编译它们。
我怀疑问题在于 Numba 尝试在编译期间应用 -O2 等激进的优化标志,从而导致过度的复杂性和处理时间。
编辑:我发现 numba从环境变量 NUMBA_OPT 中获取他的优化级别,所以我将其设置0为
import os
os.environ["NUMBA_OPT"] = "0"
Run Code Online (Sandbox Code Playgroud)
但它什么也没做。
有没有办法指示 Numba 简单地生成这些函数的汇编代码,而不尝试进一步优化?或者有其他方法可以编译它吗?
import numba as nb
import numpy as np
# This function calculates the min and the max of his parameters.
@nb.njit(nb.types.UniTuple(nb.uint64, 2)(nb.uint64, nb.uint64),fastmath=True,inline='always')
def m(a: np.uint64, b: np.uint64) -> (np.uint64, np.uint64):
return min(a, b), max(a, b)
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_1(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
return a
print('Defining function 2')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_2(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[1] = m(a[0], a[1])
return a
print('Defining function 3')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_3(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[2] = m(a[0], a[2])
a[0], a[1] = m(a[0], a[1])
a[1], a[2] = m(a[1], a[2])
return a
print('Defining function 4')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_4(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[2] = m(a[0], a[2])
a[1], a[3] = m(a[1], a[3])
a[0], a[1] = m(a[0], a[1])
a[2], a[3] = m(a[2], a[3])
a[1], a[2] = m(a[1], a[2])
return a
print('Defining function 5')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_5(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[3] = m(a[0], a[3])
a[1], a[4] = m(a[1], a[4])
a[0], a[2] = m(a[0], a[2])
a[1], a[3] = m(a[1], a[3])
a[0], a[1] = m(a[0], a[1])
a[2], a[4] = m(a[2], a[4])
a[1], a[2] = m(a[1], a[2])
a[3], a[4] = m(a[3], a[4])
a[2], a[3] = m(a[2], a[3])
return a
print('Defining function 6')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_6(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[5] = m(a[0], a[5])
a[1], a[3] = m(a[1], a[3])
a[2], a[4] = m(a[2], a[4])
a[1], a[2] = m(a[1], a[2])
a[3], a[4] = m(a[3], a[4])
a[0], a[3] = m(a[0], a[3])
a[2], a[5] = m(a[2], a[5])
a[0], a[1] = m(a[0], a[1])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[1], a[2] = m(a[1], a[2])
a[3], a[4] = m(a[3], a[4])
return a
print('Defining function 7')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_7(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[6] = m(a[0], a[6])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[0], a[2] = m(a[0], a[2])
a[1], a[4] = m(a[1], a[4])
a[3], a[6] = m(a[3], a[6])
a[0], a[1] = m(a[0], a[1])
a[2], a[5] = m(a[2], a[5])
a[3], a[4] = m(a[3], a[4])
a[1], a[2] = m(a[1], a[2])
a[4], a[6] = m(a[4], a[6])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[1], a[2] = m(a[1], a[2])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
return a
print('Defining function 8')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_8(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[2] = m(a[0], a[2])
a[1], a[3] = m(a[1], a[3])
a[4], a[6] = m(a[4], a[6])
a[5], a[7] = m(a[5], a[7])
a[0], a[4] = m(a[0], a[4])
a[1], a[5] = m(a[1], a[5])
a[2], a[6] = m(a[2], a[6])
a[3], a[7] = m(a[3], a[7])
a[0], a[1] = m(a[0], a[1])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[2], a[4] = m(a[2], a[4])
a[3], a[5] = m(a[3], a[5])
a[1], a[4] = m(a[1], a[4])
a[3], a[6] = m(a[3], a[6])
a[1], a[2] = m(a[1], a[2])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
return a
print('Defining function 9')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_9(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[3] = m(a[0], a[3])
a[1], a[7] = m(a[1], a[7])
a[2], a[5] = m(a[2], a[5])
a[4], a[8] = m(a[4], a[8])
a[0], a[7] = m(a[0], a[7])
a[2], a[4] = m(a[2], a[4])
a[3], a[8] = m(a[3], a[8])
a[5], a[6] = m(a[5], a[6])
a[0], a[2] = m(a[0], a[2])
a[1], a[3] = m(a[1], a[3])
a[4], a[5] = m(a[4], a[5])
a[7], a[8] = m(a[7], a[8])
a[1], a[4] = m(a[1], a[4])
a[3], a[6] = m(a[3], a[6])
a[5], a[7] = m(a[5], a[7])
a[0], a[1] = m(a[0], a[1])
a[2], a[4] = m(a[2], a[4])
a[3], a[5] = m(a[3], a[5])
a[6], a[8] = m(a[6], a[8])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[1], a[2] = m(a[1], a[2])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
return a
print('Defining function 10')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_10(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[8] = m(a[0], a[8])
a[1], a[9] = m(a[1], a[9])
a[2], a[7] = m(a[2], a[7])
a[3], a[5] = m(a[3], a[5])
a[4], a[6] = m(a[4], a[6])
a[0], a[2] = m(a[0], a[2])
a[1], a[4] = m(a[1], a[4])
a[5], a[8] = m(a[5], a[8])
a[7], a[9] = m(a[7], a[9])
a[0], a[3] = m(a[0], a[3])
a[2], a[4] = m(a[2], a[4])
a[5], a[7] = m(a[5], a[7])
a[6], a[9] = m(a[6], a[9])
a[0], a[1] = m(a[0], a[1])
a[3], a[6] = m(a[3], a[6])
a[8], a[9] = m(a[8], a[9])
a[1], a[5] = m(a[1], a[5])
a[2], a[3] = m(a[2], a[3])
a[4], a[8] = m(a[4], a[8])
a[6], a[7] = m(a[6], a[7])
a[1], a[2] = m(a[1], a[2])
a[3], a[5] = m(a[3], a[5])
a[4], a[6] = m(a[4], a[6])
a[7], a[8] = m(a[7], a[8])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
return a
print('Defining function 11')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_11(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[9] = m(a[0], a[9])
a[1], a[6] = m(a[1], a[6])
a[2], a[4] = m(a[2], a[4])
a[3], a[7] = m(a[3], a[7])
a[5], a[8] = m(a[5], a[8])
a[0], a[1] = m(a[0], a[1])
a[3], a[5] = m(a[3], a[5])
a[4], a[10] = m(a[4], a[10])
a[6], a[9] = m(a[6], a[9])
a[7], a[8] = m(a[7], a[8])
a[1], a[3] = m(a[1], a[3])
a[2], a[5] = m(a[2], a[5])
a[4], a[7] = m(a[4], a[7])
a[8], a[10] = m(a[8], a[10])
a[0], a[4] = m(a[0], a[4])
a[1], a[2] = m(a[1], a[2])
a[3], a[7] = m(a[3], a[7])
a[5], a[9] = m(a[5], a[9])
a[6], a[8] = m(a[6], a[8])
a[0], a[1] = m(a[0], a[1])
a[2], a[6] = m(a[2], a[6])
a[4], a[5] = m(a[4], a[5])
a[7], a[8] = m(a[7], a[8])
a[9], a[10] = m(a[9], a[10])
a[2], a[4] = m(a[2], a[4])
a[3], a[6] = m(a[3], a[6])
a[5], a[7] = m(a[5], a[7])
a[8], a[9] = m(a[8], a[9])
a[1], a[2] = m(a[1], a[2])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
a[7], a[8] = m(a[7], a[8])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
return a
print('Defining function 12')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_12(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[8] = m(a[0], a[8])
a[1], a[7] = m(a[1], a[7])
a[2], a[6] = m(a[2], a[6])
a[3], a[11] = m(a[3], a[11])
a[4], a[10] = m(a[4], a[10])
a[5], a[9] = m(a[5], a[9])
a[0], a[1] = m(a[0], a[1])
a[2], a[5] = m(a[2], a[5])
a[3], a[4] = m(a[3], a[4])
a[6], a[9] = m(a[6], a[9])
a[7], a[8] = m(a[7], a[8])
a[10], a[11] = m(a[10], a[11])
a[0], a[2] = m(a[0], a[2])
a[1], a[6] = m(a[1], a[6])
a[5], a[10] = m(a[5], a[10])
a[9], a[11] = m(a[9], a[11])
a[0], a[3] = m(a[0], a[3])
a[1], a[2] = m(a[1], a[2])
a[4], a[6] = m(a[4], a[6])
a[5], a[7] = m(a[5], a[7])
a[8], a[11] = m(a[8], a[11])
a[9], a[10] = m(a[9], a[10])
a[1], a[4] = m(a[1], a[4])
a[3], a[5] = m(a[3], a[5])
a[6], a[8] = m(a[6], a[8])
a[7], a[10] = m(a[7], a[10])
a[1], a[3] = m(a[1], a[3])
a[2], a[5] = m(a[2], a[5])
a[6], a[9] = m(a[6], a[9])
a[8], a[10] = m(a[8], a[10])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[8], a[9] = m(a[8], a[9])
a[4], a[6] = m(a[4], a[6])
a[5], a[7] = m(a[5], a[7])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
a[7], a[8] = m(a[7], a[8])
return a
print('Defining function 13')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_13(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[12] = m(a[0], a[12])
a[1], a[10] = m(a[1], a[10])
a[2], a[9] = m(a[2], a[9])
a[3], a[7] = m(a[3], a[7])
a[5], a[11] = m(a[5], a[11])
a[6], a[8] = m(a[6], a[8])
a[1], a[6] = m(a[1], a[6])
a[2], a[3] = m(a[2], a[3])
a[4], a[11] = m(a[4], a[11])
a[7], a[9] = m(a[7], a[9])
a[8], a[10] = m(a[8], a[10])
a[0], a[4] = m(a[0], a[4])
a[1], a[2] = m(a[1], a[2])
a[3], a[6] = m(a[3], a[6])
a[7], a[8] = m(a[7], a[8])
a[9], a[10] = m(a[9], a[10])
a[11], a[12] = m(a[11], a[12])
a[4], a[6] = m(a[4], a[6])
a[5], a[9] = m(a[5], a[9])
a[8], a[11] = m(a[8], a[11])
a[10], a[12] = m(a[10], a[12])
a[0], a[5] = m(a[0], a[5])
a[3], a[8] = m(a[3], a[8])
a[4], a[7] = m(a[4], a[7])
a[6], a[11] = m(a[6], a[11])
a[9], a[10] = m(a[9], a[10])
a[0], a[1] = m(a[0], a[1])
a[2], a[5] = m(a[2], a[5])
a[6], a[9] = m(a[6], a[9])
a[7], a[8] = m(a[7], a[8])
a[10], a[11] = m(a[10], a[11])
a[1], a[3] = m(a[1], a[3])
a[2], a[4] = m(a[2], a[4])
a[5], a[6] = m(a[5], a[6])
a[9], a[10] = m(a[9], a[10])
a[1], a[2] = m(a[1], a[2])
a[3], a[4] = m(a[3], a[4])
a[5], a[7] = m(a[5], a[7])
a[6], a[8] = m(a[6], a[8])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[8], a[9] = m(a[8], a[9])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
return a
print('Defining function 14')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_14(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[1] = m(a[0], a[1])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[8], a[9] = m(a[8], a[9])
a[10], a[11] = m(a[10], a[11])
a[12], a[13] = m(a[12], a[13])
a[0], a[2] = m(a[0], a[2])
a[1], a[3] = m(a[1], a[3])
a[4], a[8] = m(a[4], a[8])
a[5], a[9] = m(a[5], a[9])
a[10], a[12] = m(a[10], a[12])
a[11], a[13] = m(a[11], a[13])
a[0], a[4] = m(a[0], a[4])
a[1], a[2] = m(a[1], a[2])
a[3], a[7] = m(a[3], a[7])
a[5], a[8] = m(a[5], a[8])
a[6], a[10] = m(a[6], a[10])
a[9], a[13] = m(a[9], a[13])
a[11], a[12] = m(a[11], a[12])
a[0], a[6] = m(a[0], a[6])
a[1], a[5] = m(a[1], a[5])
a[3], a[9] = m(a[3], a[9])
a[4], a[10] = m(a[4], a[10])
a[7], a[13] = m(a[7], a[13])
a[8], a[12] = m(a[8], a[12])
a[2], a[10] = m(a[2], a[10])
a[3], a[11] = m(a[3], a[11])
a[4], a[6] = m(a[4], a[6])
a[7], a[9] = m(a[7], a[9])
a[1], a[3] = m(a[1], a[3])
a[2], a[8] = m(a[2], a[8])
a[5], a[11] = m(a[5], a[11])
a[6], a[7] = m(a[6], a[7])
a[10], a[12] = m(a[10], a[12])
a[1], a[4] = m(a[1], a[4])
a[2], a[6] = m(a[2], a[6])
a[3], a[5] = m(a[3], a[5])
a[7], a[11] = m(a[7], a[11])
a[8], a[10] = m(a[8], a[10])
a[9], a[12] = m(a[9], a[12])
a[2], a[4] = m(a[2], a[4])
a[3], a[6] = m(a[3], a[6])
a[5], a[8] = m(a[5], a[8])
a[7], a[10] = m(a[7], a[10])
a[9], a[11] = m(a[9], a[11])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
a[7], a[8] = m(a[7], a[8])
a[9], a[10] = m(a[9], a[10])
a[6], a[7] = m(a[6], a[7])
return a
print('Defining function 15')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_15(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[1], a[2] = m(a[1], a[2])
a[3], a[10] = m(a[3], a[10])
a[4], a[14] = m(a[4], a[14])
a[5], a[8] = m(a[5], a[8])
a[6], a[13] = m(a[6], a[13])
a[7], a[12] = m(a[7], a[12])
a[9], a[11] = m(a[9], a[11])
a[0], a[14] = m(a[0], a[14])
a[1], a[5] = m(a[1], a[5])
a[2], a[8] = m(a[2], a[8])
a[3], a[7] = m(a[3], a[7])
a[6], a[9] = m(a[6], a[9])
a[10], a[12] = m(a[10], a[12])
a[11], a[13] = m(a[11], a[13])
a[0], a[7] = m(a[0], a[7])
a[1], a[6] = m(a[1], a[6])
a[2], a[9] = m(a[2], a[9])
a[4], a[10] = m(a[4], a[10])
a[5], a[11] = m(a[5], a[11])
a[8], a[13] = m(a[8], a[13])
a[12], a[14] = m(a[12], a[14])
a[0], a[6] = m(a[0], a[6])
a[2], a[4] = m(a[2], a[4])
a[3], a[5] = m(a[3], a[5])
a[7], a[11] = m(a[7], a[11])
a[8], a[10] = m(a[8], a[10])
a[9], a[12] = m(a[9], a[12])
a[13], a[14] = m(a[13], a[14])
a[0], a[3] = m(a[0], a[3])
a[1], a[2] = m(a[1], a[2])
a[4], a[7] = m(a[4], a[7])
a[5], a[9] = m(a[5], a[9])
a[6], a[8] = m(a[6], a[8])
a[10], a[11] = m(a[10], a[11])
a[12], a[13] = m(a[12], a[13])
a[0], a[1] = m(a[0], a[1])
a[2], a[3] = m(a[2], a[3])
a[4], a[6] = m(a[4], a[6])
a[7], a[9] = m(a[7], a[9])
a[10], a[12] = m(a[10], a[12])
a[11], a[13] = m(a[11], a[13])
a[1], a[2] = m(a[1], a[2])
a[3], a[5] = m(a[3], a[5])
a[8], a[10] = m(a[8], a[10])
a[11], a[12] = m(a[11], a[12])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
a[7], a[8] = m(a[7], a[8])
a[9], a[10] = m(a[9], a[10])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[8], a[9] = m(a[8], a[9])
a[10], a[11] = m(a[10], a[11])
a[5], a[6] = m(a[5], a[6])
a[7], a[8] = m(a[7], a[8])
return a
print('Defining function 16')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_16(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[13] = m(a[0], a[13])
a[1], a[12] = m(a[1], a[12])
a[2], a[15] = m(a[2], a[15])
a[3], a[14] = m(a[3], a[14])
a[4], a[8] = m(a[4], a[8])
a[5], a[6] = m(a[5], a[6])
a[7], a[11] = m(a[7], a[11])
a[9], a[10] = m(a[9], a[10])
a[0], a[5] = m(a[0], a[5])
a[1], a[7] = m(a[1], a[7])
a[2], a[9] = m(a[2], a[9])
a[3], a[4] = m(a[3], a[4])
a[6], a[13] = m(a[6], a[13])
a[8], a[14] = m(a[8], a[14])
a[10], a[15] = m(a[10], a[15])
a[11], a[12] = m(a[11], a[12])
a[0], a[1] = m(a[0], a[1])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[8] = m(a[6], a[8])
a[7], a[9] = m(a[7], a[9])
a[10], a[11] = m(a[10], a[11])
a[12], a[13] = m(a[12], a[13])
a[14], a[15] = m(a[14], a[15])
a[0], a[2] = m(a[0], a[2])
a[1], a[3] = m(a[1], a[3])
a[4], a[10] = m(a[4], a[10])
a[5], a[11] = m(a[5], a[11])
a[6], a[7] = m(a[6], a[7])
a[8], a[9] = m(a[8], a[9])
a[12], a[14] = m(a[12], a[14])
a[13], a[15] = m(a[13], a[15])
a[1], a[2] = m(a[1], a[2])
a[3], a[12] = m(a[3], a[12])
a[4], a[6] = m(a[4], a[6])
a[5], a[7] = m(a[5], a[7])
a[8], a[10] = m(a[8], a[10])
a[9], a[11] = m(a[9], a[11])
a[13], a[14] = m(a[13], a[14])
a[1], a[4] = m(a[1], a[4])
a[2], a[6] = m(a[2], a[6])
a[5], a[8] = m(a[5], a[8])
a[7], a[10] = m(a[7], a[10])
a[9], a[13] = m(a[9], a[13])
a[11], a[14] = m(a[11], a[14])
a[2], a[4] = m(a[2], a[4])
a[3], a[6] = m(a[3], a[6])
a[9], a[12] = m(a[9], a[12])
a[11], a[13] = m(a[11], a[13])
a[3], a[5] = m(a[3], a[5])
a[6], a[8] = m(a[6], a[8])
a[7], a[9] = m(a[7], a[9])
a[10], a[12] = m(a[10], a[12])
a[3], a[4] = m(a[3], a[4])
a[5], a[6] = m(a[5], a[6])
a[7], a[8] = m(a[7], a[8])
a[9], a[10] = m(a[9], a[10])
a[11], a[12] = m(a[11], a[12])
a[6], a[7] = m(a[6], a[7])
a[8], a[9] = m(a[8], a[9])
return a
print('Defining function 17')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_17(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[11] = m(a[0], a[11])
a[1], a[15] = m(a[1], a[15])
a[2], a[10] = m(a[2], a[10])
a[3], a[5] = m(a[3], a[5])
a[4], a[6] = m(a[4], a[6])
a[8], a[12] = m(a[8], a[12])
a[9], a[16] = m(a[9], a[16])
a[13], a[14] = m(a[13], a[14])
a[0], a[6] = m(a[0], a[6])
a[1], a[13] = m(a[1], a[13])
a[2], a[8] = m(a[2], a[8])
a[4], a[14] = m(a[4], a[14])
a[5], a[15] = m(a[5], a[15])
a[7], a[11] = m(a[7], a[11])
a[0], a[8] = m(a[0], a[8])
a[3], a[7] = m(a[3], a[7])
a[4], a[9] = m(a[4], a[9])
a[6], a[16] = m(a[6], a[16])
a[10], a[11] = m(a[10], a[11])
a[12], a[14] = m(a[12], a[14])
a[0], a[2] = m(a[0], a[2])
a[1], a[4] = m(a[1], a[4])
a[5], a[6] = m(a[5], a[6])
a[7], a[13] = m(a[7], a[13])
a[8], a[9] = m(a[8], a[9])
a[10], a[12] = m(a[10], a[12])
a[11], a[14] = m(a[11], a[14])
a[15], a[16] = m(a[15], a[16])
a[0], a[3] = m(a[0], a[3])
a[2], a[5] = m(a[2], a[5])
a[6], a[11] = m(a[6], a[11])
a[7], a[10] = m(a[7], a[10])
a[9], a[13] = m(a[9], a[13])
a[12], a[15] = m(a[12], a[15])
a[14], a[16] = m(a[14], a[16])
a[0], a[1] = m(a[0], a[1])
a[3], a[4] = m(a[3], a[4])
a[5], a[10] = m(a[5], a[10])
a[6], a[9] = m(a[6], a[9])
a[7], a[8] = m(a[7], a[8])
a[11], a[15] = m(a[11], a[15])
a[13], a[14] = m(a[13], a[14])
a[1], a[2] = m(a[1], a[2])
a[3], a[7] = m(a[3], a[7])
a[4], a[8] = m(a[4], a[8])
a[6], a[12] = m(a[6], a[12])
a[11], a[13] = m(a[11], a[13])
a[14], a[15] = m(a[14], a[15])
a[1], a[3] = m(a[1], a[3])
a[2], a[7] = m(a[2], a[7])
a[4], a[5] = m(a[4], a[5])
a[9], a[11] = m(a[9], a[11])
a[10], a[12] = m(a[10], a[12])
a[13], a[14] = m(a[13], a[14])
a[2], a[3] = m(a[2], a[3])
a[4], a[6] = m(a[4], a[6])
a[5], a[7] = m(a[5], a[7])
a[8], a[10] = m(a[8], a[10])
a[3], a[4] = m(a[3], a[4])
a[6], a[8] = m(a[6], a[8])
a[7], a[9] = m(a[7], a[9])
a[10], a[12] = m(a[10], a[12])
a[5], a[6] = m(a[5], a[6])
a[7], a[8] = m(a[7], a[8])
a[9], a[10] = m(a[9], a[10])
a[11], a[12] = m(a[11], a[12])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[8], a[9] = m(a[8], a[9])
a[10], a[11] = m(a[10], a[11])
a[12], a[13] = m(a[12], a[13])
return a
print('Defining function 18')
@nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)
def sort_small_array_18(a: 'np.ndarray[np.uint64]') -> 'np.ndarray[np.uint64]':
a[0], a[1] = m(a[0], a[1])
a[2], a[3] = m(a[2], a[3])
a[4], a[5] = m(a[4], a[5])
a[6], a[7] = m(a[6], a[7])
a[8], a[9] = m(a[8], a[9])
a[10], a[11] = m(a[10], a[11])
a[12], a[13] = m(a[12], a[13])
a[14], a[15] = m(a[14], a[15])
a[16], a[17] = m(a[16], a[17])
a[0], a[2] = m(a[0], a[2])
a[1], a[3] = m(a[1], a[3])
a[4], a[12] = m(a[4], a[12])
a[5], a[13] = m(a[5], a[13])
a[6], a[8] = m(a[6], a[8])
a[9], a[11] = m(a[9], a[11])
a[14], a[16] = m(a[14], a[16])
a[15], a[17] = m(a[15], a[17])
a[0], a[14] = m(a[0], a[14])
a[1], a[16] = m(a[1], a[16])
a[2], a[15] = m(a[2], a[15])
a[3], a[17] = m(a[3], a[17])
a[0], a[6] = m(a[0], a[6])
a[1], a[10] = m(a[1], a[10])
a[2], a[9] = m(a[2], a[9
乍一看,时代并不是指数增长,而是平方增长。我们可以通过绘制编译时间与函数编号的关系图来看出这一点。
\n\n问题是行数也显着增加,并且与编译函数所花费的时间高度相关。排序网络的大小不断增加,O(n log\xc2\xb2 n)因此编译时间至少应增加同样多的时间。
事实上,某些编译步骤并不与输入代码大小线性增长,而是更多。诸如优化之类的一些步骤可以在O(n\xc2\xb2). 例如,寄存器分配是一个特别昂贵的步骤,尤其是在您的情况下。事实上,众所周知,如果我们想找到最佳解决方案,尽管编译器在实践中使用相对较快的启发式算法,那么它是 NP 难的。有关编译器算法复杂性的更多信息,请阅读这篇文章。O(n)就您而言,我预计大多数操作都会在一段时间内完成O(n log n)。这意味着O(v log(v)**3)函数编号的编译时间v。事实证明,当前结果与O(n log(n))编译时间(其中n是行数)非常匹配,证实了上面提到的预期编译时间复杂度。
最重要的是,Numba 的设计目的并不是特别快速地生成代码,因为 Numba 代码预计很小(与本机代码甚至 Cython 代码相反)。AFAIK,负责将 Python 代码转换为 LLVM 中间表示的代码的很大一部分是用 Python 完成的(因此它被解释并且相当慢)。
\n话虽这么说,一个问题是你使用inline=\'always\'告诉 Numba 手动内联函数本身。这会显着增加生成的 IR 代码的大小,从而增加编译时间。请注意,如果 LLVM 认为内联值得,则可以在没有编译标志的情况下执行内联。这个手动内联过程占用了我机器上的大部分编译时间。事实上,这里是有和没有它的执行时间:
最后一个(n\xc2\xb017)的时间从 3010 ms 减少到 1360 ms。
\n还有另一个问题:当您执行 naive 时a[1],Numba 会生成一段代码来访问 Numpy 数组中的数据。问题是 Numba 需要处理 Numpy 的环绕特性,即不保证为 1 的步长,并基于此提取 Numpy 数组的值。这会为这样一个在代码中重复大约 2000 次的基本原语操作生成重要的代码...问题是 Numba 不支持基本的 C 数组,所以我认为我们无法在 Numba 中完全消除这种开销(Cython 可以使用用户标志)禁用 Numpy/CPython 功能(例如环绕)。
尽管如此,我们仍然可以避免 Numba 生成大量重复代码,并将负担转移到更优化的 LLVM 上。一种解决方案是修改函数m,以便它根据索引执行数组读/写。这大大减少了复制代码的大小。这是修改后的部分代码:
import numba as nb\nimport numpy as np\n\n# This function calculates the min and the max of his parameters.\n@nb.njit((nb.uint64[:], nb.uint64, nb.uint64),fastmath=True)\ndef m(arr: \'np.ndarray[np.uint64]\', ia: np.uint64, ib: np.uint64):\n va = arr[ia]\n vb = arr[ib]\n arr[ia] = min(va, vb)\n arr[ib] = max(va, vb)\n\ndef run1(): \n @nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)\n def sort_small_array_1(a: \'np.ndarray[np.uint64]\') -> \'np.ndarray[np.uint64]\':\n return a\n\ndef run2(): \n print(\'Defining function 2\')\n @nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)\n def sort_small_array_2(a: \'np.ndarray[np.uint64]\') -> \'np.ndarray[np.uint64]\':\n m(a, 0, 1)\n return a\n\ndef run3():\n print(\'Defining function 3\')\n @nb.njit(nb.uint64[:](nb.uint64[:]),fastmath=True)\n def sort_small_array_3(a: \'np.ndarray[np.uint64]\') -> \'np.ndarray[np.uint64]\':\n m(a, 0, 2)\n m(a, 0, 1)\n m(a, 1, 2)\n return a\n\n# [...]\nRun Code Online (Sandbox Code Playgroud)\n以下是最终的时间安排:
\n\n除了最快的函数之外,几乎所有函数的编译时间都快了大约10 倍(例如,最后一个函数为 3010 毫秒,而最后一个函数为 295 毫秒)。
\n由于 LLVM 可能会生成次优代码,因此需要检查结果代码的性能。请注意,提前编译更适合此类用例,因为编译只需完成一次。事实上,像 C 这样的原生语言非常适合这种用例。代码应该编译得更快,并且生成的代码也会更快,因为没有环绕。最重要的是,可以并行编译多个 C 文件,据我所知,这在 Numba 中是不可能的。
\n顺便说一句,大型分拣网络通常效率低下。首先,它们的大小对于大型代码来说不是最佳的(O(n log\xc2\xb2 n)当快速排序算法在O(n log n). xc2\xb5ops 缓存、循环流检测器、L1 指令缓存等)。我建议您重新考虑使用大于 15 的排序网络的需要,尤其是对于 Numba 代码。优化的插入排序可能比大排序网络更快因此,有几十件物品。
请注意,代码的“编译函数”部分是无用的,因为在提供签名时会急切地编译Numba 函数(代码中就是这种情况)。
\n经验法则是“不要重复自己”(DRY)!这显然对人类、编译器和 CPU 都不利。
\n| 归档时间: |
|
| 查看次数: |
91 次 |
| 最近记录: |