小编Pro*_*mer的帖子

为什么dim=1 在torch.argmax 中返回行索引?

我正在argmax研究 PyTorch 的功能,其定义为:

torch.argmax(input, dim=None, keepdim=False)
Run Code Online (Sandbox Code Playgroud)

考虑一个例子

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
Run Code Online (Sandbox Code Playgroud)

在这里,当我使用 dim=1 而不是搜索列向量时,该函数会搜索行向量,如下所示。

print(a) :   
tensor([[-1.7739,  0.8073,  0.0472, -0.4084],  
        [ 0.6378,  0.6575, -1.2970, -0.0625],  
        [ 1.7970, -1.3463,  0.9011, -0.8704],  
        [ 1.5639,  0.7123,  0.0385,  1.8410]])  

print(torch.argmax(a, dim=1))  
tensor([1, 1, 0, 3])
Run Code Online (Sandbox Code Playgroud)

就我的假设而言,dim = 0 代表行,dim = 1 代表列。

python matrix argmax pytorch tensor

10
推荐指数
1
解决办法
5343
查看次数

标签 统计

argmax ×1

matrix ×1

python ×1

pytorch ×1

tensor ×1