如何在 Pytorch 中使用 torch.nn.Sequential 实现我自己的 ResNet?

Cha*_*ker 5 machine-learning neural-network deep-learning conv-neural-network pytorch

我想实现一个 ResNet 网络(或者更确切地说,残差块),但我真的希望它采用顺序网络形式。

我所说的顺序网络形式如下:

## mdl5, from cifar10 tutorial
mdl5 = nn.Sequential(OrderedDict([
    ('pool1', nn.MaxPool2d(2, 2)),
    ('relu1', nn.ReLU()),
    ('conv1', nn.Conv2d(3, 6, 5)),
    ('pool1', nn.MaxPool2d(2, 2)),
    ('relu2', nn.ReLU()),
    ('conv2', nn.Conv2d(6, 16, 5)),
    ('relu2', nn.ReLU()),
    ('Flatten', Flatten()),
    ('fc1', nn.Linear(1024, 120)), # figure out equation properly
    ('relu4', nn.ReLU()),
    ('fc2', nn.Linear(120, 84)),
    ('relu5', nn.ReLU()),
    ('fc3', nn.Linear(84, 10))
]))
Run Code Online (Sandbox Code Playgroud)

但当然,NN 乐高积木是“ResNet”。

我知道等式是这样的:

在此处输入图片说明

但我不确定如何在 Pytorch AND Sequential 中做到这一点。顺序对我来说很关键!


交叉发布:

Szy*_*zke 7

你不能单独使用它,torch.nn.Sequential因为它需要操作,顾名思义,按顺序进行,而你的操作是并行的。

原则上,您block可以像这样轻松构建自己的:

import torch

class ResNet(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inputs):
        return self.module(inputs) + inputs
Run Code Online (Sandbox Code Playgroud)

哪一个可以使用这样的东西:

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=7),
    # 32 filters in and out, no max pooling so the shapes can be added
    ResNet(
        torch.nn.Sequential(
            torch.nn.Conv2d(32, 32, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
            torch.nn.Conv2d(32, 32, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
        )
    ),
    # Another ResNet block, you could make more of them
    # Downsampling using maxpool and others could be done in between etc. etc.
    ResNet(
        torch.nn.Sequential(
            torch.nn.Conv2d(32, 32, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
            torch.nn.Conv2d(32, 32, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
        )
    ),
    # Pool all the 32 filters to 1, you may need to use `torch.squeeze after this layer`
    torch.nn.AdaptiveAvgPool2d(1),
    # 32 10 classes
    torch.nn.Linear(32, 10),
)
Run Code Online (Sandbox Code Playgroud)

通常被忽视的事实(在涉及浅层网络时没有真正的后果)是跳过连接应该没有任何非线性,如ReLU卷积层,这就是你在上面看到的(来源:深度残差网络中的身份映射)。