从一维张量中提取前k个值索引

Ada*_*han 4 lua torch

给定Torch中的一维张量(torch.Tensor),其中包含可以比较的值(例如浮点数),我们如何提取该张量中前k个值的索引?

除了蛮力方法外,我还在寻找Torch / lua提供的一些API调用,它可以有效地执行此任务。

Cha*_*tor 8

您可以使用topk函数。

例如:

import torch

t = torch.tensor([5.7, 1.4, 9.5, 1.6, 6.1, 4.3])

values,indices = t.topk(2)

print(values)
print(indices)
Run Code Online (Sandbox Code Playgroud)

结果:

tensor([9.5000, 6.1000])
tensor([2, 4])
Run Code Online (Sandbox Code Playgroud)


del*_*eil 5

自发出请求以来,#496 Torch现在包括一个名为的内置API torch.topk。例:

> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}

-- obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
 1
 2
 3
[torch.DoubleTensor of size 3]

-- you can also get the indices in addition
> res, ind = t:topk(3)
> print(ind)
 2
 4
 6
[torch.LongTensor of size 3]

-- alternatively you can obtain the k largest elements as follow
-- (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
 9
 8
 7
[torch.DoubleTensor of size 3]
Run Code Online (Sandbox Code Playgroud)

在撰写本文时,CPU实现遵循一种狭窄的方法(有计划在将来进行改进)。话虽如此,目前正在审查针对割炬的优化GPU实现。