How does PyTorch's nn.Module register submodules?

JK.*_*ong 5 python pytorch

When reading the source code (Python) of torch.nn.Module, I found that the attribute self._modules is used in many functions, such as self.modules() and self.children(). However, I did not find any function that updates it. So where is self._modules updated? And how does PyTorch's nn.Module register submodules?

class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def named_modules(self, memo=None, prefix=''):
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix):
                    yield m

小智 6

Modules and parameters are usually registered by setting an attribute on an instance of nn.Module. In particular, this behavior is implemented via a customized __setattr__ method:

def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

See https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py for the full method.
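
As a quick check of the branches above, here is a minimal sketch (the class Net and the attribute names are just illustrative, not from the PyTorch source):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # isinstance(value, Module) branch -> ends up in self._modules['conv']
        self.conv = nn.Conv2d(3, 8, 3)
        # isinstance(value, Parameter) branch -> ends up in self._parameters['scale']
        self.scale = nn.Parameter(torch.ones(1))
        # a plain tensor matches none of the branches, so it falls through to
        # object.__setattr__ and is stored in __dict__ without being registered
        self.offset = torch.zeros(1)

net = Net()
print(list(net._modules))        # ['conv']
print(list(net._parameters))     # ['scale']
print('offset' in net.__dict__)  # True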


Yue*_*Tau 6

To add some details to the answer above:

  • The layers of a network (which themselves inherit from nn.Module) are stored in Module._modules, which is initialized in __construct:

    def __init__(self):
        self.__construct()
        # initialize self.training separately from the rest of the internal
        # state, as it is managed differently by nn.Module and ScriptModule
        self.training = True
    
    def __construct(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        # ...
        self._modules = OrderedDict()
    
  • self._modules is updated in __setattr__; __setattr__(obj, name, value) is called whenever obj.name = value is executed. For example, if self.conv1 = nn.Conv2d(128, 256, 3, 1, 1) is defined while initializing a network that inherits from nn.Module, the following code in nn.Module.__setattr__ is executed (see the sketch after this list):

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]
    
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            # ...
        elif params is not None and name in params:
            # ...
        else:
            modules = self.__dict__.get('_modules') # equivalent to modules = self._modules
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                # register the given layer (nn.Conv2d) with its name (conv1)
                # equivalent to self._modules['conv1'] = nn.Conv2d(128, 256, 3, 1, 1)
                modules[name] = value
    
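
Putting both points together, a small hedged example (the names Net and conv1 are illustrative; the Conv2d arguments are the ones used above):

import torch.nn as nn

# right after nn.Module.__init__, _modules exists but is empty
m = nn.Module()
print(m._modules)   # OrderedDict()

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # this assignment triggers nn.Module.__setattr__, which runs
        # self._modules['conv1'] = value as in the branch shown above
        self.conv1 = nn.Conv2d(128, 256, 3, 1, 1)

net = Net()
print(net._modules)               # OrderedDict([('conv1', Conv2d(128, 256, ...))])
print(dict(net.named_modules()))  # {'': Net(...), 'conv1': Conv2d(...)}
                                  # this is what modules()/children() iterate over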

A question from the comments:

Do you know how this works together with the fact that torch lets you supply your own forward method?

If you run the forward pass of a network that inherits from nn.Module, nn.Module.__call__ is invoked, and it in turn calls self.forward. However, forward is overridden when you implement your own network, so the overridden version is what actually runs.
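
A minimal sketch of that dispatch (again with illustrative names):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):   # overrides nn.Module.forward
        return self.conv1(x)

net = Net()
x = torch.randn(1, 3, 32, 32)
# net(x) goes through nn.Module.__call__, which runs any registered hooks
# and then dispatches to the forward defined above
y = net(x)
print(y.shape)   # torch.Size([1, 8, 32, 32])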