pytorch Argmax 中发生冲突时的索引选择

dar*_*ffy 5 python argmax pytorch tensor

我一直在尝试学习张量运算,而这个让我陷入了困境。
假设我有一个张量 t:

    t = torch.tensor([
        [1,0,0,2],
        [0,3,3,0],
        [4,0,0,5]
    ], dtype  = torch.float32)
Run Code Online (Sandbox Code Playgroud)

现在这是一个 2 阶张量,我们可以对每个阶/维度应用 argmax。假设我们将其应用于 dim = 1

t.max(dim = 1)
(tensor([2., 3., 5.]), tensor([3, 2, 3]))
Run Code Online (Sandbox Code Playgroud)

现在我们可以看到结果正如预期的那样,沿 dim =1 的张量有 2,3 和 5 作为最大元素。但是3上有冲突。有两个值完全相似。
怎么解决的?是任意选择的吗?有没有像LR一样的选择顺序,指数值越高?
如果您能了解如何解决此问题,我将不胜感激!

cle*_*ros 6

这是一个很好的问题,我自己也曾多次被问到过。最简单的答案是,无法保证torch.argmax(或torch.max(x, dim=k),在指定 dim 时也返回索引)将一致地返回相同的索引。相反,它将返回argmax 值的任何有效索引,可能是随机的。正如官方论坛中的该帖子所讨论的那样,这被认为是期望的行为。(我知道我不久前读过的另一个线程使这一点更加明确,但我无法再次找到它)。

话虽如此,由于这种行为对于我的用例来说是不可接受的,我编写了以下函数来查找最左边和最右边的索引(请注意,这condition是您传入的函数对象):

def __consistent_args(input, condition, indices):
    assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
    mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
    return torch.argmax(mask, dim=1)


def consistent_find_leftmost(input, condition):
    indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
    return __consistent_args(input, condition, indices)


def consistent_find_rightmost(input, condition):
    indices = torch.arange(0, input.size(1), 1, dtype=torch.float, device=input.device)
    return __consistent_args(input, condition, indices)

# one example:
consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)                                                                                                                                     
# will return: 
# tensor([6])
Run Code Online (Sandbox Code Playgroud)

希望他们能帮忙!(哦,如果您有更好的实现相同功能,请告诉我)