我想创建一个 Numba 编译的函数,它将指针或数组的内存地址作为参数并对其进行计算,例如修改基础数据。
用于说明这一点的纯 python 版本如下所示:
import ctypes
import numba as nb
import numpy as np
arr = np.arange(5).astype(np.double) # create arbitrary numpy array
def modify_data(addr):
""" a function taking the memory address of an array to modify it """
ptr = ctypes.c_void_p(addr)
data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
data += 2
addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])
Run Code Online (Sandbox Code Playgroud)
正如您在示例中看到的,数组arr已被修改,但未将其显式传递给函数。在我的用例中,数组的形状和数据类型是已知的,并且始终保持不变,这应该简化界面。
我现在尝试编译该modify_data函数,但失败了。我的第一次尝试是使用
shape = arr.shape
dtype = arr.dtype
@nb.njit
def modify_data_nb(ptr):
data …Run Code Online (Sandbox Code Playgroud)