Pytorch 张量 - 如何通过特定张量获取索引

do.*_*do. 4 python pytorch

我有张量

t = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]])
Run Code Online (Sandbox Code Playgroud)

和一个查询张量

q = torch.tensor([1, 0, 0, 0])
Run Code Online (Sandbox Code Playgroud)

有没有办法获得q像的索引

indexes = t.index(q) # get back [0, 3]
Run Code Online (Sandbox Code Playgroud)

在 pytorch 中?

Sha*_*hai 6

怎么样

In [1]: torch.nonzero((t == q).sum(dim=1) == t.size(1))
Out[1]: 
tensor([[ 0],
        [ 3]])
Run Code Online (Sandbox Code Playgroud)

Comparing在和之间t == q执行逐元素比较,因为您正在寻找整行匹配,您需要沿着行查看哪一行是完美匹配。tq.sum(dim=1)== t.size(1)


从 v0.4.1 开始,torch.all()支持dim参数:

torch.all(t==q, dim=1)
Run Code Online (Sandbox Code Playgroud)