相关疑难解决方法(0)

如何将数组指针传递给 Numba 函数?

我想创建一个 Numba 编译的函数,它将指针或数组的内存地址作为参数并对其进行计算,例如修改基础数据。

用于说明这一点的纯 python 版本如下所示:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array


def modify_data(addr):
    """ a function taking the memory address of an array to modify it """
    ptr = ctypes.c_void_p(addr)
    data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
    data += 2

addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])
Run Code Online (Sandbox Code Playgroud)

正如您在示例中看到的,数组arr已被修改,但未将其显式传递给函数。在我的用例中,数组的形状和数据类型是已知的,并且始终保持不变,这应该简化界面。

1.尝试:天真的抖动

我现在尝试编译该modify_data函数,但失败了。我的第一次尝试是使用

shape = arr.shape
dtype = arr.dtype

@nb.njit
def modify_data_nb(ptr):
    data …
Run Code Online (Sandbox Code Playgroud)

python arrays numpy numba

4
推荐指数
1
解决办法
2157
查看次数

标签 统计

arrays ×1

numba ×1

numpy ×1

python ×1