chm*_*777 5 python inheritance super pytorch
为什么super(LR, self).__init__()
需要在下面的代码中调用?否则我会收到错误“AttributeError: cannot assign module before Module.init () call”。该错误是由self.linear = nn.Linear(input_size, output_size)
.
我不明白调用super(LR, self).__init__()
和能够将 nn.Linear 对象分配给 self.linear之间有什么联系。nn.Linear 是一个单独的对象,它可以分配给任何类之外的变量,那么为什么super(LR, self).__init__()
需要调用将 Linear 对象分配给类内的 self.linear 呢?
class LR(nn.Module):
# Constructor
def __init__(self, input_size, output_size):
# Inherit from parent
super(LR, self).__init__()
self.test = 1
self.linear = nn.Linear(input_size, output_size)
# Prediction function
def forward(self, x):
out = self.linear(x)
return out
Run Code Online (Sandbox Code Playgroud)
当您self.linear = nn.Linear(...)
在自定义类中编写代码时,您实际上是在调用__setattr__
类的函数。碰巧的是,当您扩展时nn.Module
,您的类继承了很多东西,其中之一是__setattr__
. 正如你在实现中看到的(我只发布了下面的相关部分),如果nn.Linear
是 的实例nn.Module
,你的类必须有一个名为 的属性_modules
,否则它会抛出AttributeError
你得到的:
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
# [...]
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError("cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
modules[name] = value
Run Code Online (Sandbox Code Playgroud)
如果您查看nn.Module
's __init__
,您会看到它self._modules
在那里初始化:
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict() # <---- here
Run Code Online (Sandbox Code Playgroud)
缓冲区和参数也是如此。
您需要 super() 调用,以便 mn.Module 类本身被初始化。在 Python 中,超类构造函数/初始化器不会自动调用 - 它们必须显式调用,这就是 super() 的作用 - 它计算出要调用的超类。
我假设您使用的是 Python 3 - 在这种情况下,您不需要 super() 调用中的参数 - 这就足够了:
super().__init__()
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
1586 次 |
最近记录: |