为什么 PyTorch 自定义模块中需要超级构造函数?

chm*_*777 5 python inheritance super pytorch

为什么super(LR, self).__init__()需要在下面的代码中调用?否则我会收到错误“AttributeError: cannot assign module before Module.init () call”。该错误是由self.linear = nn.Linear(input_size, output_size).

我不明白调用super(LR, self).__init__()和能够将 nn.Linear 对象分配给 self.linear之间有什么联系。nn.Linear 是一个单独的对象,它可以分配给任何类之外的变量,那么为什么super(LR, self).__init__()需要调用将 Linear 对象分配给类内的 self.linear 呢?

class LR(nn.Module):
    
    # Constructor
    def __init__(self, input_size, output_size):
        
        # Inherit from parent
        super(LR, self).__init__()
        self.test = 1
        self.linear = nn.Linear(input_size, output_size)
        
    
    # Prediction function
    def forward(self, x):
        out = self.linear(x)
        return out
Run Code Online (Sandbox Code Playgroud)

Ber*_*iel 8

当您self.linear = nn.Linear(...)在自定义类中编写代码时,您实际上是在调用__setattr__类的函数。碰巧的是,当您扩展时nn.Module,您的类继承了很多东西,其中之一是__setattr__. 正如你在实现中看到的(我只发布了下面的相关部分),如果nn.Linear是 的实例nn.Module,你的类必须有一个名为 的属性_modules,否则它会抛出AttributeError你得到的:

def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # [...]
    modules = self.__dict__.get('_modules')
    if isinstance(value, Module):
        if modules is None:
            raise AttributeError("cannot assign module before Module.__init__() call")
        remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
        modules[name] = value
Run Code Online (Sandbox Code Playgroud)

如果您查看nn.Module's __init__,您会看到它self._modules在那里初始化:

def __init__(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")

    self.training = True
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._non_persistent_buffers_set = set()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()                         # <---- here
Run Code Online (Sandbox Code Playgroud)

缓冲区和参数也是如此。

  • @chmod_777 开源是最好的:)如果您认为这回答了您的问题并且可能对其他人有帮助,请考虑投票或标记为答案。 (2认同)

Ton*_* 66 5

您需要 super() 调用,以便 mn.Module 类本身被初始化。在 Python 中,超类构造函数/初始化器不会自动调用 - 它们必须显式调用,这就是 super() 的作用 - 它计算出要调用的超类。

我假设您使用的是 Python 3 - 在这种情况下,您不需要 super() 调用中的参数 - 这就足够了:

        super().__init__()
Run Code Online (Sandbox Code Playgroud)