如何有效地获取Torch张量中的最大值索引?

Chr*_*ris 7 max indices pytorch

假定具有以下形状的火炬张量:

x = torch.rand(20, 1, 120, 120)
Run Code Online (Sandbox Code Playgroud)

我现在想要的是获取每个120x120矩阵的最大值的索引。为了简化问题,我将首先x.squeeze()使用shape [20, 120, 120]。然后,我想获取火炬张量,它是具有shape的索引列表[20, 2]

我该如何快速完成?

tej*_*i88 11

torch.topk()就是你要找的。从文档中,

torch.topk输入ķ暗淡=无最大=真排序=真OUT =无) - >(张量LongTensor

沿给定维度返回k 给定input张量的 最大元素 。

  • 如果 dim 未给出,则选择输入的最后一个维度。

  • 如果 largest 是, False 则返回 k 个最小元素。

  • 返回(值,索引)的命名元组,其中索引是原始输入张量中元素的索引。

  • 布尔选项 sorted if True,将确保返回的 k 元素本身已排序

  • 有用的功能知道,但它并没有回答原来的问题。OP 希望获得 20 个 120x120 矩阵中每个矩阵的最大元素的索引。也就是说,她想要 20 个 2D 坐标,每个矩阵一个。topk 仅返回最大化维度中最大元素的索引。 (7认同)

blu*_*nox 6

如果我理解正确,您不需要值,而是索引。不幸的是,没有现成的解决方案。存在一个argmax()函数,但我看不出如何让它完全按照你的意愿去做。

所以这是一个小的解决方法,效率也应该没问题,因为我们只是对张量进行除法:

n = torch.tensor(4)
d = torch.tensor(4)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# since argmax() does only return the index of the flattened
# matrix block we have to calculate the indices by ourself 
# by using / and % (// would also work, but as we are dealing with
# type torch.long / works as well
indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)
Run Code Online (Sandbox Code Playgroud)

n代表您的第一个维度和d最后两个维度。我在这里取较小的数字来显示结果。但当然这也适用于n=20d=120

n = torch.tensor(20)
d = torch.tensor(120)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)
Run Code Online (Sandbox Code Playgroud)

这是n=4and的输出d=4

tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
          [0.6767, 0.7439, 0.5984, 0.5499],
          [0.8465, 0.7276, 0.3078, 0.3882],
          [0.1001, 0.0705, 0.2007, 0.4051]]],


        [[[0.7520, 0.4528, 0.0525, 0.9253],
          [0.6946, 0.0318, 0.5650, 0.7385],
          [0.0671, 0.6493, 0.3243, 0.2383],
          [0.6119, 0.7762, 0.9687, 0.0896]]],


        [[[0.3504, 0.7431, 0.8336, 0.0336],
          [0.8208, 0.9051, 0.1681, 0.8722],
          [0.5751, 0.7903, 0.0046, 0.1471],
          [0.4875, 0.1592, 0.2783, 0.6338]]],


        [[[0.9398, 0.7589, 0.6645, 0.8017],
          [0.9469, 0.2822, 0.9042, 0.2516],
          [0.2576, 0.3852, 0.7349, 0.2806],
          [0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
        [3, 2],
        [1, 1],
        [1, 0]])
Run Code Online (Sandbox Code Playgroud)

我希望这就是你想要的!:)

编辑:

这是一个稍微修改过的,它可能会稍微快一点(我猜不是很多:),但它更简单和更漂亮:

而不是像以前那样:

m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
Run Code Online (Sandbox Code Playgroud)

已经对argmax值进行了必要的重塑:

m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)
Run Code Online (Sandbox Code Playgroud)

但正如评论中提到的。我认为不可能从中得到更多。

您可以做的一件事是,如果对您来说获得最后一点性能改进真的很重要,那就是将上述函数作为 pytorch 的低级扩展(如在 C++ 中)实现。

这只会给你一个你可以调用它的函数,并且会避免缓慢的 Python 代码。

https://pytorch.org/tutorials/advanced/cpp_extension.html