自动编码器 MaxUnpool2d 缺少“索引”参数

Roo*_*ter 4 pytorch

以下模型返回错误: TypeError:forward() Missing 1 requiredpositional argument: 'indices'

我已经用尽了许多在线示例,它们看起来都与我的代码相似。我的 maxpool 层返回 unpool 层的输入和索引。关于出了什么问题有什么想法吗?

class autoencoder(nn.Module):
def __init__(self):
    super(autoencoder, self).__init__()
    self.encoder = nn.Sequential(
        ...
        nn.MaxPool2d(2, stride=1, return_indices=True)
    )
    self.decoder = nn.Sequential(
        nn.MaxUnpool2d(2, stride=1),
        ...
    )

def forward(self, x):
    x = self.encoder(x)
    x = self.decoder(x)
    return x
Run Code Online (Sandbox Code Playgroud)

Roo*_*ter 6

与这里的问题类似,解决方案似乎是将 maxunpool 层与解码器分开并显式传递其所需的参数。nn.Sequential仅采用一个参数。

class SimpleConvAE(nn.Module):
def __init__(self):
    super().__init__()

    # input: batch x 3 x 32 x 32 -> output: batch x 16 x 16 x 16
    self.encoder = nn.Sequential(
        ...
        nn.MaxPool2d(2, stride=2, return_indices=True),
    )

    self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0)

    self.decoder = nn.Sequential(
        ...
    )

def forward(self, x):
    encoded, indices = self.encoder(x)
    out = self.unpool(encoded, indices)
    out = self.decoder(out)
    return (out, encoded)
Run Code Online (Sandbox Code Playgroud)