roz*_*ang 11 numba pytorch tensor
我是 Numba 新手,我需要使用 Numba 来加速一些 Pytorch 功能。但我发现即使是一个非常简单的功能也不起作用:(
import torch
import numba
@numba.njit()
def vec_add_odd_pos(a, b):
res = 0.
for pos in range(len(a)):
if pos % 2 == 0:
res += a[pos] + b[pos]
return res
x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)
Run Code Online (Sandbox Code Playgroud)
但出现以下错误
def vec_add_odd_pos(a, b):
res = 0.
^
This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'torch.Tensor'>
- argument 1: cannot determine Numba type of <class 'torch.Tensor'>
Run Code Online (Sandbox Code Playgroud)
谁能帮我?包含更多示例的链接也将不胜感激。谢谢。
pix*_*lou 15
Pytorch 现在在 GPU 张量上公开了一个接口,可供 numba 直接使用:
numba.cuda.as_cuda_array(tensor)
测试脚本提供了一些使用示例:https://github.com/pytorch/pytorch/blob/master/test/test_numba_integration.py
Joh*_*mar 10
正如其他人提到的,numba 目前不支持 torch 张量,仅支持 numpy 张量。然而, TorchScript也有类似的目标。然后你的函数可以重写如下:
import torch
@torch.jit.script
def vec_add_odd_pos(a, b):
res = 0.
for pos in range(len(a)):
if pos % 2 == 0:
res += a[pos] + b[pos]
return res
x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)
Run Code Online (Sandbox Code Playgroud)
请注意:虽然您说您的代码片段只是一个简单的示例,但 for 循环确实很慢并且运行 TorchScript 可能对您没有多大帮助,您应该不惜一切代价避免它们,并且仅在不存在其他解决方案时才使用 then。话虽这么说,以下是如何以更高效的方式实现您的函数:
def vec_add_odd_pos(a, b):
evenids = torch.arange(len(a)) % 2 == 0
return (a[evenids] + b[evenids]).sum()
Run Code Online (Sandbox Code Playgroud)