Jam*_*gan 6 python memoization numba
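I've just discovered numba and learned that getting the best performance requires adding @njit to most functions, so that numba rarely leaves LLVM mode.
I still have a few expensive lookup functions that would benefit from memoization, but so far none of my attempts has produced a workable solution that compiles without errors.
Putting a memoization decorator before @njit leaves numba unable to do type inference. Putting the decorator after @njit fails when numba tries to compile the decorator. Numba rejects global variables, even when using numba.typed.Dict. Removing @njit also causes type errors when the function is called from other @njit functions.
What is the correct way to add memoization to a function when working inside numba?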
import functools
import time
import fastcache
import numba
import numpy as np
import toolz
from numba import njit
from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize
# @fastcache.clru_cache(None) # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None) # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None) # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None) # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks
# @fastcache.clru_cache(None) # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None) # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @toolz.memoize # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks
# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
# Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
    global _expensive_cache
    if _expensive_cache is None:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        _expensive_cache = bitmasks
    return _expensive_cache
# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
# is not supported as globals are considered compile-time constants and there is no known way to compile
# a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
    key_type = numba.types.string,
    value_type = numba.uint64[:]
)
@njit
def expensive_cache():
    global cache
    if "expensive" not in cache:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        cache["expensive"] = bitmasks
    return cache["expensive"]
# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit()
def _expensive_wrapped():
    cache = []
    def wrapper(bitmasks):
        # len() never returns None, so test for an empty cache instead
        if len(cache) == 0:
            bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
            cache.append(bitmasks)
        return cache[0]
    return wrapper
expensive_wrapped = _expensive_wrapped()
@njit
def loop(count):
    for n in range(count):
        expensive()
        # expensive_nojit()
        # expensive_cache()
        # expensive_global()
        # expensive_wrapped()
def main():
    time_start = time.perf_counter()
    count = 10000
    loop(count)
    time_taken = time.perf_counter() - time_start
    print(f'{count} loops in {time_taken:.4f}s')
loop(1) # precache numba
main()
# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s
You've mentioned that your actual code is more complex, but looking at your minimal example, I would recommend the following pattern:
@njit
def loop(count):
    expensive_result = expensive()
    for i in range(count):
        do_something(count, expensive_result)
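Instead of using a cache, you can precompute the value outside the loop and hand the result to the loop body. I would recommend passing every argument explicitly rather than relying on global variables (always, but especially when using numba jit).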
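For reference, here is a minimal runnable sketch of that pattern. It rests on a few assumptions of mine: count_set_bits is a hypothetical stand-in for do_something (the question does not show the real loop body), the table is filled with an explicit loop rather than the original list comprehension, and the try/except fallback to a no-op njit is only there so the snippet still runs where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python when numba is unavailable.
    def njit(func):
        return func


@njit
def expensive():
    # Same lookup table as in the question: entry n has only bit n set.
    bitmasks = np.empty(64, dtype=np.uint64)
    for n in range(64):
        bitmasks[n] = np.uint64(1) << np.uint64(n)
    return bitmasks


@njit
def count_set_bits(value, bitmasks):
    # Hypothetical loop body standing in for do_something():
    # counts how many table entries overlap `value`.
    total = 0
    for k in range(bitmasks.shape[0]):
        if (value & bitmasks[k]) != np.uint64(0):
            total += 1
    return total


@njit
def loop(count):
    # Compute the table once, outside the loop, and pass it explicitly.
    expensive_result = expensive()
    total = 0
    for i in range(count):
        total += count_set_bits(np.uint64(i), expensive_result)
    return total


print(loop(8))  # set bits in 0..7 add up to 12
```

Because the table is an ordinary argument, every function involved stays in nopython mode and no global or closure state is needed. If the table has to be shared across several jitted entry points, it can be built once by the caller and passed into each of them the same way.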