多维张量的前 K 个指数

Mit*_*iku 1 python matrix-indexing pytorch tensor

我有一个二维张量,我想获得前 k 个值的索引。我知道pytorch 的 topk功能。pytorch 的 topk 函数的问题在于,它计算某个维度上的 topk 值。我想获得两个维度的 topk 值。

例如对于以下张量

a = torch.tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])

Run Code Online (Sandbox Code Playgroud)

pytorch 的 topk 函数会给我以下信息。

values, indices = torch.topk(a, 3)

print(indices)
# tensor([[1, 2, 0],
#        [0, 2, 1],
#        [0, 1, 4],
#        [1, 4, 3],
#        [1, 0, 4]])

Run Code Online (Sandbox Code Playgroud)

但我想得到以下

tensor([[0, 1],
        [2, 0],
        [3, 1]])

Run Code Online (Sandbox Code Playgroud)

这是 2D 张量中 9 的索引。

有什么方法可以使用 pytorch 实现这一点吗?

muj*_*iga 5

v, i = torch.topk(a.flatten(), 3)
print (np.array(np.unravel_index(i.numpy(), a.shape)).T)
Run Code Online (Sandbox Code Playgroud)

输出:

[[3 1]
 [2 0]
 [0 1]]
Run Code Online (Sandbox Code Playgroud)
  1. 展平并找到前 k
  2. 将一维索引转换为二维使用 unravel_index