Pytorch:如何在二维张量的每一行中找到第一个非零元素的索引?

Ima*_*bdi 5 python machine-learning pytorch

我有一个二维张量,每行都有一些非零元素,如下所示:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
Run Code Online (Sandbox Code Playgroud)

我想要一个包含每行中第一个非零元素索引的张量:

indices = tensor([2],
                 [3])
Run Code Online (Sandbox Code Playgroud)

我如何在 Pytorch 中计算它?

小智 11

我简化了 Iman 的方法来执行以下操作:

idx = torch.arange(tmp.shape[1], 0, -1)
tmp2= tmp * idx
indices = torch.argmax(tmp2, 1, keepdim=True)
Run Code Online (Sandbox Code Playgroud)


Sep*_*rvi 6

假设所有非零值都相等,argmax则返回第一个索引。

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]])
indices = tmp.argmax(1)
Run Code Online (Sandbox Code Playgroud)


Ima*_*bdi 5

我可以为我的问题找到一个棘手的答案:

  tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                     [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
  idx = reversed(torch.Tensor(range(1,8)))
  print(idx)

  tmp2= torch.einsum("ab,b->ab", (tmp, idx))

  print(tmp2)

  indices = torch.argmax(tmp2, 1, keepdim=True)
  print(indeces)
Run Code Online (Sandbox Code Playgroud)

结果是:

tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
       [0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
        [3]])
Run Code Online (Sandbox Code Playgroud)