Gug*_*lie 6 python computer-vision pytorch
我正在尝试对通道维度进行 maxpooling:
class ChannelPool(nn.Module):
def forward(self, input):
return torch.max(input, dim=1)
Run Code Online (Sandbox Code Playgroud)
但我得到了错误
AttributeError: 'torch.return_types.max' object has no attribute 'dim'
Run Code Online (Sandbox Code Playgroud)
Gug*_*lie 15
该torch.max函数返回一个元组,因此:
class ChannelPool(nn.Module):
def forward(self, input):
input_max, input_indexes = torch.max(input, dim=1)
return input_max
Run Code Online (Sandbox Code Playgroud)
我最近遇到了同样的错误。torch.max()有两种形式。
如果你只给出一个输入张量(没有像dim...这样的其他参数),max() 函数将返回一个张量
if you specify other args (for example dim=0), max() function will returns a namedtuple: (values, indices). I guess the values is what you want.
| 归档时间: |
|
| 查看次数: |
7978 次 |
| 最近记录: |