Numba-compatible memoization

Jam*_*gan 6 python memoization numba

I've just discovered numba, and learnt that optimal performance requires adding @njit to most functions, so that numba rarely drops out of LLVM mode.

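I still have a few expensive/lookup functions that would benefit from memoization, but so far none of my attempts has produced a workable solution that compiles without errors.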

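  • Placing a generic memoization decorator above @njit stops numba from doing type inference when the function is called from other @njit code.
  • Placing the decorator below @njit fails because numba cannot compile the decorator.
  • Numba doesn't like global variables, even when using numba.typed.Dict.
  • Numba doesn't like closures being used to store mutable state.
  • Removing @njit also causes type errors when the function is called from other @njit functions.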
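For comparison, the typed-Dict idea from the third bullet does compile in nopython mode when the dictionary is passed in as an explicit argument instead of being read as a global. A minimal sketch, assuming one cached table per bit width: `make_cache`, `bitmasks_cached` and `repeated_lookups` are hypothetical names, and the `try`/`except` branch (an identity `njit` plus a plain `dict`) exists only so the sketch also runs where numba is not installed.

```python
import numpy as np

try:
    from numba import njit, types
    from numba.typed import Dict

    def make_cache():
        # Typed dict (int64 -> 1-D uint64 array) that nopython code can read and mutate
        return Dict.empty(key_type=types.int64, value_type=types.uint64[:])
except ImportError:  # assumption: plain-Python fallback so the sketch still runs
    def njit(func):
        return func

    def make_cache():
        return {}


@njit
def bitmasks_cached(nbits, cache):
    # Build the table for this width only on the first call; reuse it afterwards
    if nbits not in cache:
        out = np.empty(nbits, dtype=np.uint64)
        for n in range(nbits):
            out[n] = np.uint64(1) << np.uint64(n)
        cache[nbits] = out
    return cache[nbits]


@njit
def repeated_lookups(count, cache):
    # The cache is an explicit argument, not a global, so numba can type it
    total = 0
    for _ in range(count):
        total += len(bitmasks_cached(64, cache))
    return total


cache = make_cache()
total = repeated_lookups(10_000, cache)  # the table is built once, then reused
```

The dictionary lives on the Python side and is handed to every call, so its contents survive between calls without any global or closure state.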

What is the correct way to add memoization to a function when working with numba?

import functools
import time

import fastcache
import numba
import numpy as np
import toolz
from numba import njit

from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize



# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None)  # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None)   # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize               # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks



# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks


# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
#      Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
    global _expensive_cache
    if _expensive_cache is None:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        _expensive_cache = bitmasks
    return _expensive_cache


# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
#      is not supported as globals are considered compile-time constants and there is no known way to compile
#      a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
    key_type   = numba.types.string,
    value_type = numba.uint64[:]
)
@njit
def expensive_cache():
    global cache
    if "expensive" not in cache:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        cache["expensive"] = bitmasks
    return cache["expensive"]


# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit()
def _expensive_wrapped():
    cache = []
    def wrapper():
        if len(cache) == 0:
            bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
            cache.append(bitmasks)
        return cache[0]
    return wrapper
expensive_wrapped = _expensive_wrapped()

@njit
def loop(count):
    for n in range(count):
        expensive()
        # expensive_nojit()
        # expensive_cache()
        # expensive_global()
        # expensive_wrapped()

def main():
    time_start = time.perf_counter()

    count = 10000
    loop(count)

    time_taken = time.perf_counter() - time_start
    print(f'{count} loops in {time_taken:.4f}s')


loop(1)  # precache numba
main()

# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s
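
The typed.Dict error message above points at one more option: numba treats global variables as compile-time constants. A minimal sketch, assuming the lookup table is read-only and can be built once at import time in plain Python, so the array is frozen into the compiled function and nothing is left to memoize at run time. `BITMASKS` and `count_set_bits` are hypothetical names, and the `try`/`except` fallback to an identity `njit` exists only so the sketch runs without numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback so the sketch still runs
    def njit(func):
        return func

# Computed once, in ordinary Python, when the module is imported.
# Numba freezes global arrays into the compiled code as constants, so the
# jitted function only reads BITMASKS and never reassigns or modifies it.
BITMASKS = np.array([1 << n for n in range(64)], dtype=np.uint64)


@njit
def count_set_bits(x):
    # Hypothetical consumer of the precomputed table: popcount via lookups
    xu = np.uint64(x)
    count = 0
    for n in range(64):
        if xu & BITMASKS[n]:
            count += 1
    return count
```

One caveat: because the value is frozen at compile time, reassigning BITMASKS from Python later is not seen by an already compiled function; it would have to be recompiled.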

Car*_*orn 1

You mentioned that your real code is more complex, but looking at your minimal example, I'd recommend the following pattern:

@njit
def loop(count):
    expensive_result = expensive()  # computed once, before the loop
    for i in range(count):
        do_something(count, expensive_result)  # do_something: placeholder for the real loop body

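Rather than caching, you can precompute the result outside the loop and pass it to the loop body. I'd also recommend passing every argument explicitly instead of relying on globals (always, but especially when using numba's jit).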
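A runnable version of the pattern above, as a sketch: `do_something` is a hypothetical stand-in for the real loop body, and the `try`/`except` fallback is only there so the sketch runs without numba. The expensive table is built once per `loop` call and handed to the body as an ordinary argument, so no cache, global or closure is involved.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback so the sketch still runs
    def njit(func):
        return func


@njit
def expensive():
    # Same 64-entry bitmask table as in the question, built with explicit uint64 shifts
    bitmasks = np.empty(64, dtype=np.uint64)
    for n in range(64):
        bitmasks[n] = np.uint64(1) << np.uint64(n)
    return bitmasks


@njit
def do_something(i, bitmasks):
    # Hypothetical loop body: 1 if bit (i % 64) of i is set, else 0
    if np.uint64(i) & bitmasks[i % 64]:
        return 1
    return 0


@njit
def loop(count):
    bitmasks = expensive()  # computed once, before the loop
    total = 0
    for i in range(count):
        total += do_something(i, bitmasks)  # precomputed result passed in explicitly
    return total
```

If the table is needed across many `loop` calls, the same idea moves one level up: compute it once in Python and pass it into `loop` as an extra argument.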