相关疑难解决方法(0)

Torch argmax 获取 N 个最高值

我想做一些类似 argmax 但有多个最高值的事情。我知道如何使用普通的 torch.argmax

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  1.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a)
tensor(0)
Run Code Online (Sandbox Code Playgroud)

但现在我需要找到前 N 个值的索引。所以像这样的事情

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  1.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a,top_n=2)
tensor([0,1])
Run Code Online (Sandbox Code Playgroud)

我在 pytorch 中没有找到任何能够执行此操作的函数,有人知道吗?

python machine-learning deep-learning pytorch

6
推荐指数
1
解决办法
1万
查看次数

如何使用 Pytorch 和/或 Numpy 高效查找多维矩阵数组中最大值的索引

背景

处理高维数据在机器学习中很常见。例如,在卷积神经网络 (CNN) 中,每个输入图像的尺寸可以是 256x256,并且每个图像可以具有 3 个颜色通道(红色、绿色和蓝色)。如果我们假设模型一次接收一批 16 张图像,则进入 CNN 的输入的维度为[16,3,256,256]。每个单独的卷积层都期望 形式的数据[batch_size, in_channels, in_y, in_x],并且所有这些数量通常会逐层变化(batch_size 除外)。我们用于表示由值组成的矩阵的术语[in_y, in_x]“特征映射”,这个问题涉及在给定层的每个特征映射中查找最大值及其索引。

我为什么要这样做?我想对每个特征图应用一个掩码,并且我想应用以每个特征图中的最大值为中心的掩码,为此,我需要知道每个最大值所在的位置。这种掩模应用是在模型的训练和测试期间完成的,因此效率对于减少计算时间至关重要。有许多 Pytorch 和 Numpy 解决方案可用于查找单例最大值和索引,以及查找沿单个维度的最大值或索引,但没有(我能找到)专用且高效的内置函数来查找最大值的索引一次沿着 2 个或更多维度。是的,我们可以嵌套在单个维度上运行的函数,但这些是一些效率最低的方法。

我尝试过的

  • 我看过这个 Stackoverflow 问题,但作者正在处理一个特殊情况的 4D 数组,它被简单地压缩为 3D 数组。接受的答案是专门针对这种情况的,而指向 TopK 的答案是误导性的,因为它不仅在单个维度上运行,而且k=1根据所提出的问题需要这样做,从而发展为常规torch.max调用。
  • 我看过这个 Stackoverflow 问题,但是这个问题及其答案,重点关注单一维度。
  • 我已经看过这个 Stackoverflow 问题,但我已经知道答案的方法,因为我在自己的答案中独立地表述了它我修改了该方法非常低效)。
  • 我看过这个Stackoverflow问题,但它不满足这个问题的关键部分,即与效率有关。
  • 我阅读了许多其他 Stackoverflow 问题和答案,以及 Numpy 文档、Pytorch 文档和 Pytorch 论坛上的帖子。
  • 我已经尝试实施很多不同的方法来解决这个问题,足以让我创建这个问题,以便我可以回答它并回馈社区以及将来寻找此问题解决方案的任何人。

绩效标准

如果我问有关效率的问题,我需要清楚地详细说明期望。我正在尝试为上述问题找到一种省时的解决方案(空间是次要的),而无需编写 C …

python numpy max numba pytorch

5
推荐指数
1
解决办法
2475
查看次数

从一维张量中提取前k个值索引

给定Torch中的一维张量(torch.Tensor),其中包含可以比较的值(例如浮点数),我们如何提取该张量中前k个值的索引?

除了蛮力方法外,我还在寻找Torch / lua提供的一些API调用,它可以有效地执行此任务。

lua torch

4
推荐指数
2
解决办法
3022
查看次数

标签 统计

python ×2

pytorch ×2

deep-learning ×1

lua ×1

machine-learning ×1

max ×1

numba ×1

numpy ×1

torch ×1