我是面向对象的新手,并且在理解以下内容时遇到困难:
import torch.nn as nn
class mynet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(20, 64)
def forward(self, x):
x = self.fc1(x)
Run Code Online (Sandbox Code Playgroud)
该行self.fc1 = nn.Linear(20, 64) 应该为我的类创建一个成员变量 fc1,对吗?但是 nn.Linear(20, 64) 的返回值是多少?
根据文档, nn.Linear 定义为
class torch.nn.Linear(in_features: int, out_features: int, bias: bool = True)。
然而,在我的基本 OOP 教程中,我只看到类似class CLASSNAME(BASECLASS)CLASSNAME 类继承自 BASECLASS 的内容。该文档将所有这些内容写在括号之间的方式意味着什么?
另外,这条线x=fc1(x)不知何故让它看起来好像 fc1 现在是一个函数。
我在这里似乎缺乏 OOP 知识...任何帮助表示赞赏!
首先让我们看一下这个
self.fc1 = nn.Linear(20, 64)
Run Code Online (Sandbox Code Playgroud)
对于对 Python 和 OOP 有基本了解的人来说,这部分可能很熟悉。在这里,我们只是创建一个类的新实例,并使用位置参数和分别对应于和 来nn.Linear初始化该类。文档中的参数是要传递给 的方法的预期参数。2064in_featuresout_featuresnn.Linear__init__
现在来说说可能有点令人困惑的部分
x = self.fc1(x)
Run Code Online (Sandbox Code Playgroud)
该类nn.Linear是可调用的,因为它的父类nn.Module实现了一个名为 的特殊方法__call__。self.fc1这意味着您可以像函数一样对待并执行类似的操作x = self.fc1(x),这相当于x = self.fc1.__call__(x).