获取张量中每行的最大值 [PyTorch]

ch1*_*era 7 python max deep-learning pytorch

假设我有一个以下形式的张量

[[-5, 0, -1],
 [3, 100, 87],
 [17, -34, 2],
 [45, 1, 25]]
Run Code Online (Sandbox Code Playgroud)

我想找到每一行中的最大值并返回一个 1 级张量,如下所示:

[0,
 100,
 17,
 45]
Run Code Online (Sandbox Code Playgroud)

我该如何在 PyTorch 中做到这一点?

小智 10

您可以使用该torch.max()功能。所以你可以做类似的事情

x = torch.Tensor([[-5, 0, -1],
                  [3, 100, 87],
                  [17, -34, 2],
                  [45, 1, 25]])
out, inds = torch.max(x,dim=1)
Run Code Online (Sandbox Code Playgroud)

这将返回每行的最大值(维度 1)。它将返回最大值及其索引。