小编ire*_*ire的帖子

Numba python3 出现错误 [GPU ufunc 要求数组参数具有确切的类型。]

我正在尝试使用 numba 在我的 GPU 上执行 np.diff。
这是我使用的脚本;

import numpy as np
import numba

@numba.vectorize(["float32(float32, float32)"], target='cuda')
def vector_diff_axis0(a, b):
    return a + b

def my_diff(A, axis=0):
    if (axis == 0):
        return vector_diff_axis0(A[1:], A[:-1])
    if (axis == 1):
        return vector_diff_axis0(A[:,1:], A[:,:-1])

A = np.matrix([
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
    [9, 8, 7, 6, 5],
    [4, 3, 2, 1, 0],
    [0, 2, 4, 6, 8]
    ], dtype='float32')

C = my_diff(A, axis=1)
print (str(C))
Run Code Online (Sandbox Code Playgroud)

这是我得到的错误;

TypeError: No matching …
Run Code Online (Sandbox Code Playgroud)

python numpy python-3.x numba

5
推荐指数
1
解决办法
1898
查看次数

标签 统计

numba ×1

numpy ×1

python ×1

python-3.x ×1