Dav*_*ker 4 python arrays numpy numba
我想创建一个 Numba 编译的函数,它将指针或数组的内存地址作为参数并对其进行计算,例如修改基础数据。
用于说明这一点的纯 python 版本如下所示:
import ctypes
import numba as nb
import numpy as np
arr = np.arange(5).astype(np.double) # create arbitrary numpy array
def modify_data(addr):
""" a function taking the memory address of an array to modify it """
ptr = ctypes.c_void_p(addr)
data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
data += 2
addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])
Run Code Online (Sandbox Code Playgroud)
正如您在示例中看到的,数组arr已被修改,但未将其显式传递给函数。在我的用例中,数组的形状和数据类型是已知的,并且始终保持不变,这应该简化界面。
我现在尝试编译该modify_data函数,但失败了。我的第一次尝试是使用
shape = arr.shape
dtype = arr.dtype
@nb.njit
def modify_data_nb(ptr):
data = nb.carray(ptr, shape, dtype=dtype)
data += 2
ptr = ctypes.c_void_p(addr)
modify_data_nb(ptr) # <<< error
Run Code Online (Sandbox Code Playgroud)
这失败了cannot determine Numba type of <class 'ctypes.c_void_p'>,即它不知道如何解释指针。
我尝试放置显式类型,
arr_ptr_type = nb.types.CPointer(nb.float64)
shape = arr.shape
@nb.njit(nb.types.void(arr_ptr_type))
def modify_data_nb(ptr):
""" a function taking the memory address of an array to modify it """
data = nb.carray(ptr, shape)
data += 2
Run Code Online (Sandbox Code Playgroud)
但这没有帮助。它没有抛出任何错误,但我不知道如何调用该函数modify_data_nb。我尝试了以下选项
modify_data_nb(arr.ctypes.data)
# TypeError: No matching definition for argument type(s) int64
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject
ptr = ctypes.c_void_p(arr.ctypes.data)
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject
Run Code Online (Sandbox Code Playgroud)
有没有办法获得正确的指针格式,arr以便我可以将其传递给 Numba 编译的modify_data_nb函数?或者,是否有另一种方法将内存位置传递给函数。
scipy.LowLevelCallablescipy.LowLevelCallable我通过使用它的魔力取得了一些进展:
arr = np.arange(3).astype(np.double)
print(arr)
# >>> array([0., 1., 2.])
# create the function taking a pointer
shape = arr.shape
dtype = arr.dtype
@nb.cfunc(nb.types.void(nb.types.CPointer(nb.types.double)))
def modify_data(ptr):
data = nb.carray(ptr, shape, dtype=dtype)
data += 2
modify_data_llc = LowLevelCallable(modify_data.ctypes).function
# create pointer to array
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
# call the function only with the pointer
modify_data_llc(ptr)
# check whether array got modified
print(arr)
# >>> array([2., 3., 4.])
Run Code Online (Sandbox Code Playgroud)
我现在可以调用函数来访问数组,但该函数不再是 Numba 函数。特别是,它不能用于其他 Numba 功能。
感谢伟大的@stuartarchibald,我现在有了一个可行的解决方案:
import ctypes
import numba as nb
import numpy as np
arr = np.arange(5).astype(np.double) # create arbitrary numpy array
print(arr)
@nb.extending.intrinsic
def address_as_void_pointer(typingctx, src):
""" returns a void pointer from a given memory address """
from numba.core import types, cgutils
sig = types.voidptr(src)
def codegen(cgctx, builder, sig, args):
return builder.inttoptr(args[0], cgutils.voidptr_t)
return sig, codegen
addr = arr.ctypes.data
@nb.njit
def modify_data():
""" a function taking the memory address of an array to modify it """
data = nb.carray(address_as_void_pointer(addr), arr.shape, dtype=arr.dtype)
data += 2
modify_data()
print(arr)
Run Code Online (Sandbox Code Playgroud)
关键是新address_as_void_pointer函数将内存地址(以 int 形式给出)转换为可由numba.carray.
| 归档时间: |
|
| 查看次数: |
2157 次 |
| 最近记录: |