在torch.sum() 中dim=-1 或-2 是什么意思?

sky*_*ark 8 python pytorch

让我以一个二维矩阵为例:

mat = torch.arange(9).view(3, -1)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

torch.sum(mat, dim=-2)

tensor([ 9, 12, 15])
Run Code Online (Sandbox Code Playgroud)

我发现的结果torch.sum(mat, dim=-2)等于torch.sum(mat, dim=0)dim=-1等于dim=1。我的问题是如何理解这里的负面维度。如果输入矩阵有 3 个或更多维度怎么办?

cri*_*tig 12

张量有多个维度,如下图所示。有向前和向后索引。正向索引使用正整数,反向索引使用负整数。

例子:

-1 将是最后一个,在我们的例子中它将是 dim=2

-2 会变暗=1

-3 将变暗=0

在此处输入图片说明

  • 通常从零而不是从 1 开始计数 (7认同)

Sim*_*mdi 5

减号基本上意味着您向后浏览维度。设 A 是一个 n 维矩阵。然后dim=n-1=-1,dim=n-2=-2,...,dim=1=-(n-1),dim=0=-n。有关更多信息,请参阅numpy 文档,因为 pytorch 很大程度上基于 numpy。