如何编写在CPU模式和CUDA设备模式下使用的Numba函数?

roz*_*ang -1 cuda numba

我想编写一个在 CPU 模式和 CUDA 设备模式下使用的 Numba 函数。当然,我可以使用和不使用 cuda.jit 装饰器编写两个相同的函数。例如:

from numba import cuda, njit

@njit("i4(i4, i4)")
def func_cpu(a, b)
    return a + b

@cuda.jit("i4(i4, i4)", device=True)
def func_gpu(a, b)
    return a + b
Run Code Online (Sandbox Code Playgroud)

但在软件工程中它是丑陋的。有没有一种更优雅的方式,即将代码组合在一个函数中?

Rut*_*ies 6

装饰器本质上是一个函数,它将函数作为输入,并返回一个(经常修改的)函数作为输出。像 Numba 一样添加参数和关键字参数使其变得稍微复杂一些(内部),但您可以将其视为一个嵌套函数,其中外部函数再次返回一个装饰器。

@因此,您可以将其作为任何函数调用并捕获输出,而不是像现在一样将其用作装饰器(使用 ) 。然后输出也将是一个可调用函数。

这允许用纯 Python 编写函数,然后根据需要应用任意数量的“装饰器”。例如:

from numba import cuda, njit

def func_py(a, b)
    return a + b

func_njit = njit("i4(i4, i4)")(func_py)
func_gpu = cuda.jit("i4(i4, i4)", device=True)(func_py)

assert func_py(4, 3) == func_njit(4, 3)
assert func_py(4, 3) == func_gpu(4, 3)
Run Code Online (Sandbox Code Playgroud)