Jai*_*tas 6 python machine-learning deep-learning pytorch
我想做一些类似 argmax 但有多个最高值的事情。我知道如何使用普通的 torch.argmax
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398, 1.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
>>> torch.argmax(a)
tensor(0)
Run Code Online (Sandbox Code Playgroud)
但现在我需要找到前 N 个值的索引。所以像这样的事情
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398, 1.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
>>> torch.argmax(a,top_n=2)
tensor([0,1])
Run Code Online (Sandbox Code Playgroud)
我在 pytorch 中没有找到任何能够执行此操作的函数,有人知道吗?
Great! So you need the first k largest elements of a tensor.
[答案 1]无论维度如何,您都需要所有元素中前 k 个最大的元素。因此,展平张量并使用该torch.topk函数获取前 3 个(例如)元素的索引:
>>> a = torch.randn(5,4)
>>> a
tensor([[ 0.8292, -0.5123, -0.0741, -0.3043],
[-0.4340, -0.7763, 1.9716, -0.5620],
[ 0.1582, -1.2000, 1.0202, -1.5202],
[-0.3617, -0.2479, 0.6204, 0.2575],
[ 1.8025, 1.9864, -0.8013, -0.7508]])
>>> torch.topk(a.flatten(), 3).indices
tensor([17, 6, 16])
Run Code Online (Sandbox Code Playgroud)
[答案 2]您需要给定输入张量沿给定维度的 k 个最大元素。因此,请参阅此处torch.topk给出的函数的 PyTorch 文档。