PyTorch 张量高级索引

Dr.*_*ick 9 python numpy pytorch

假设我有一个矩阵和一个向量,如下所示:

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])
Run Code Online (Sandbox Code Playgroud)

有没有办法切片它x[y]所以结果是:

res = [1, 6, 8]
Run Code Online (Sandbox Code Playgroud)

所以基本上我取第一个元素y并取x对应于第一行和元素列的元素。

干杯

FBr*_*esi 7

您可以将相应的行索引指定为:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

x[range(x.shape[0]), y]
tensor([1, 6, 8])
Run Code Online (Sandbox Code Playgroud)

  • 我能够毫无问题地重现这个答案。请修改您的问题@yarin (3认同)
  • 抱歉,我不小心输错了括号,现在可以了 (2认同)