小编Ima*_*bdi的帖子

Pytorch:如何在二维张量的每一行中找到第一个非零元素的索引?

我有一个二维张量,每行都有一些非零元素,如下所示:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
Run Code Online (Sandbox Code Playgroud)

我想要一个包含每行中第一个非零元素索引的张量:

indices = tensor([2],
                 [3])
Run Code Online (Sandbox Code Playgroud)

我如何在 Pytorch 中计算它?

python machine-learning pytorch

5
推荐指数
3
解决办法
2633
查看次数

标签 统计

machine-learning ×1

python ×1

pytorch ×1