Python:为什么 numba 需要更多时间?

use*_*187 1 python numba

我正在使用从list comprehension收集数据。代码如下:listtuples

data = [result[0] for result in results] #results is a list of tuples and i take first element from each tuple.
Run Code Online (Sandbox Code Playgroud)

这有效并且一切都很好。

最近我遇到了numba可以提高循环执行速度的模块?

所以我尝试这样做来测试时间:

import numba
from numba import literal_unroll
from datetime import datetime
import logging

numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)

@numba.jit(nopython=True)
def loop_faster(results):
    for result in literal_unroll(results):
        print(result)
    
tuples = (1.1, "Hello", 1, "World", "Tuple-1")

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

loop_faster(tuples)

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

for result in tuples:
        print(result)

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
Run Code Online (Sandbox Code Playgroud)

我引用了此链接literal_unrollhttps ://numba.pydata.org/numba-doc/dev/reference/pysupported.html

然而, for 循环的执行效果似乎比numba方法要好得多。

上述程序的结果:

2021-03-04 10:51:36.385
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.234
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.236
Run Code Online (Sandbox Code Playgroud)

为什么会出现这种行为呢?numba 花了将近 10 秒

对于我的情况,从元组的第 n 个元素形成一个列表,我如何使用numba模块实现?

Art*_*oul 5

原因很简单,第一次运行函数时,很多时间都花在将你的代码编译成C++代码和机器代码上,这就是numba的JIT。

因此,您必须cache = True@numba.jit装饰器添加参数以预缓存编译版本。此外,您还必须在测量时间之前调用一次运行以确保编译。此外,您还必须运行更多的循环迭代才能更精确地测量时间,仅运行 10 毫秒是不够的。

下面的代码做了上面提到的三件事。您可以看到 numba 带来了5.5x时间加速。

我还修改了您的代码以实现一些不同的逻辑,因为打印逻辑无法正确测量时间。Numba 适用于计算量很大的代码,而不是用于打印到控制台。所以我只是创建了随机整数数组并计算了这个数组+1。作为示例代码,这足以让您看到 Numba 运行得更快。

在线尝试一下!

import numba, logging, random
from datetime import datetime

numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)

@numba.njit(cache = True)
def loop_faster(results, n):
    for i in range(n):
        res = []
        for result in numba.literal_unroll(results):
            res.append(result + 1)
    
t = tuple(random.randrange(1 << 20) for i in range(100))

loop_faster(t, 10) # pre-compile numba

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

loop_faster(t, 1 << 16)

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])

for i in range(1 << 16):
    res = []
    for result in t:
        res.append(result + 1)

print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
Run Code Online (Sandbox Code Playgroud)

输出:

2021-03-04 12:18:04.491
2021-03-04 12:18:04.840
2021-03-04 12:18:06.774
Run Code Online (Sandbox Code Playgroud)