我正在使用从list comprehension收集数据。代码如下:listtuples
data = [result[0] for result in results] #results is a list of tuples and i take first element from each tuple.
Run Code Online (Sandbox Code Playgroud)
这有效并且一切都很好。
最近我遇到了numba可以提高循环执行速度的模块?
所以我尝试这样做来测试时间:
import numba
from numba import literal_unroll
from datetime import datetime
import logging
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)
@numba.jit(nopython=True)
def loop_faster(results):
for result in literal_unroll(results):
print(result)
tuples = (1.1, "Hello", 1, "World", "Tuple-1")
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
loop_faster(tuples)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
for result in tuples:
print(result)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
Run Code Online (Sandbox Code Playgroud)
我引用了此链接literal_unroll:https ://numba.pydata.org/numba-doc/dev/reference/pysupported.html
然而, for 循环的执行效果似乎比numba方法要好得多。
上述程序的结果:
2021-03-04 10:51:36.385
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.234
1.1
Hello
1
World
Tuple-1
2021-03-04 10:51:47.236
Run Code Online (Sandbox Code Playgroud)
为什么会出现这种行为呢?numba 花了将近 10 秒
对于我的情况,从元组的第 n 个元素形成一个列表,我如何使用numba模块实现?
原因很简单,第一次运行函数时,很多时间都花在将你的代码编译成C++代码和机器代码上,这就是numba的JIT。
因此,您必须cache = True向@numba.jit装饰器添加参数以预缓存编译版本。此外,您还必须在测量时间之前调用一次运行以确保编译。此外,您还必须运行更多的循环迭代才能更精确地测量时间,仅运行 10 毫秒是不够的。
下面的代码做了上面提到的三件事。您可以看到 numba 带来了5.5x时间加速。
我还修改了您的代码以实现一些不同的逻辑,因为打印逻辑无法正确测量时间。Numba 适用于计算量很大的代码,而不是用于打印到控制台。所以我只是创建了随机整数数组并计算了这个数组+1。作为示例代码,这足以让您看到 Numba 运行得更快。
import numba, logging, random
from datetime import datetime
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)
@numba.njit(cache = True)
def loop_faster(results, n):
for i in range(n):
res = []
for result in numba.literal_unroll(results):
res.append(result + 1)
t = tuple(random.randrange(1 << 20) for i in range(100))
loop_faster(t, 10) # pre-compile numba
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
loop_faster(t, 1 << 16)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
for i in range(1 << 16):
res = []
for result in t:
res.append(result + 1)
print(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
Run Code Online (Sandbox Code Playgroud)
输出:
2021-03-04 12:18:04.491
2021-03-04 12:18:04.840
2021-03-04 12:18:06.774
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
682 次 |
| 最近记录: |