我正在关注这篇文章,并尝试实现此功能:
def replace_max_pooling(model):
'''
The function replaces max pooling layers with average pooling layers with
the following properties: kernel_size=2, stride=2, padding=0.
'''
for layer in model.layers:
if layer is max pooling:
replace
Run Code Online (Sandbox Code Playgroud)
但是我在迭代中得到一个错误说:
ModuleAttributeError: 'VGG' 对象没有属性 'layers'...
我怎样才能正确地做到这一点?
Torchvision 提供的 VGG 模型包含三个组件:features
子模块avgpool
(自适应平均池)和classifier
. 你需要寻找到网络的头部,其中卷积和泳池层位于:features
。
您可以在各层循环nn.Module
使用named_children()
。但是,还有其他方法可以解决此问题。您可以使用它isinstance
来确定图层是否属于特定类型。
在这个特定模型中,层由那里的索引命名。因此,为了在 中找到适当的层nn.Module
并覆盖它们,我们可以将名称转换为 int。
for i, layer in m.features.named_children():
if isinstance(layer, torch.nn.MaxPool2d):
m.features[int(i)] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
Run Code Online (Sandbox Code Playgroud)
预先设置:
import torch
import torch.nn as nn
m = models.vgg16()
Run Code Online (Sandbox Code Playgroud)