Pytorch Tensor如何获得特定值的索引

Han*_*ing 17 python pytorch

在python列表中,我们可以使用list.index(somevalue).pytorch怎么做到这一点?
例如:

    a=[1,2,3]
    print(a.index(2))
Run Code Online (Sandbox Code Playgroud)

然后,1将输出.pytorch张量如何在不将其转换为python列表的情况下执行此操作?

Man*_*nas 25

我认为没有直接转换list.index()为pytorch函数.但是,您可以使用tensor==number然后使用该nonzero()功能获得类似的结果.例如:

t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero())
Run Code Online (Sandbox Code Playgroud)

这段代码返回

1

[torch.LongTensor尺寸1x1]


dop*_*xxx 9

对于多维张量,您可以执行以下操作:

(tensor == target_value).nonzero(as_tuple=True)
Run Code Online (Sandbox Code Playgroud)

生成的张量的形状为number_of_matches x tensor_dimension。例如,假设tensor是一个3 x 4张量(这意味着维度为 2),结果将是一个二维张量,其中包含行中匹配项的索引。

tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
>>> tensor([[0, 1],
        [0, 2],
        [1, 2]])
Run Code Online (Sandbox Code Playgroud)