AttributeError: 'torch.return_types.max' 对象没有属性 'dim' - Maxpooling Channel

Gug*_*lie 6 python computer-vision pytorch

我正在尝试对通道维度进行 maxpooling:

class ChannelPool(nn.Module):
    def forward(self, input):
        return torch.max(input, dim=1)
Run Code Online (Sandbox Code Playgroud)

但我得到了错误

AttributeError: 'torch.return_types.max' object has no attribute 'dim'
Run Code Online (Sandbox Code Playgroud)

Gug*_*lie 15

torch.max函数返回一个元组,因此:

class ChannelPool(nn.Module):
    def forward(self, input):
        input_max, input_indexes = torch.max(input, dim=1)
        return input_max
Run Code Online (Sandbox Code Playgroud)


Zha*_*oYi 6

我最近遇到了同样的错误。torch.max()有两种形式。

  • 如果你只给出一个输入张量(没有像dim...这样的其他参数),max() 函数将返回一个张量

  • if you specify other args (for example dim=0), max() function will returns a namedtuple: (values, indices). I guess the values is what you want.