如何将 Numba 用于 Pytorch 张量?

roz*_*ang 11 numba pytorch tensor

我是 Numba 新手,我需要使用 Numba 来加速一些 Pytorch 功能。但我发现即使是一个非常简单的功能也不起作用:(

import torch
import numba

@numba.njit()
def vec_add_odd_pos(a, b):
    res = 0.
    for pos in range(len(a)):
        if pos % 2 == 0:
            res += a[pos] + b[pos]
    return res

x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)
Run Code Online (Sandbox Code Playgroud)

但出现以下错误

def vec_add_odd_pos(a, b):
    res = 0.
    ^

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'torch.Tensor'>
- argument 1: cannot determine Numba type of <class 'torch.Tensor'>
Run Code Online (Sandbox Code Playgroud)

谁能帮我?包含更多示例的链接也将不胜感激。谢谢。

pix*_*lou 15

Pytorch 现在在 GPU 张量上公开了一个接口,可供 numba 直接使用:

numba.cuda.as_cuda_array(tensor)

测试脚本提供了一些使用示例:https://github.com/pytorch/pytorch/blob/master/test/test_numba_integration.py


Joh*_*mar 10

正如其他人提到的,numba 目前不支持 torch 张量,仅支持 numpy 张量。然而, TorchScript也有类似的目标。然后你的函数可以重写如下:

import torch

@torch.jit.script
def vec_add_odd_pos(a, b):
    res = 0.
    for pos in range(len(a)):
        if pos % 2 == 0:
            res += a[pos] + b[pos]
    return res

x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)
Run Code Online (Sandbox Code Playgroud)

请注意:虽然您说您的代码片段只是一个简单的示例,但 for 循环确实很慢并且运行 TorchScript 可能对您没有多大帮助,您应该不惜一切代价避免它们,并且仅在不存在其他解决方案时才使用 then。话虽这么说,以下是如何以更高效的方式实现您的函数:

def vec_add_odd_pos(a, b):
    evenids = torch.arange(len(a)) % 2 == 0
    return (a[evenids] + b[evenids]).sum()
Run Code Online (Sandbox Code Playgroud)

  • _for 循环真的很慢。_ 这不正是 numba 试图解决的问题吗?即,您有某种逻辑无法(或非常尴尬)矢量化的情况?在 numba 中实现 for 循环当然并不慢。如果 TorchScript 中的 for 循环很慢,那么它的目标与 numba 并不相似。 (2认同)