我是面向对象的新手,并且在理解以下内容时遇到困难:
import torch.nn as nn
class mynet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(20, 64)
def forward(self, x):
x = self.fc1(x)
Run Code Online (Sandbox Code Playgroud)
该行self.fc1 = nn.Linear(20, 64) 应该为我的类创建一个成员变量 fc1,对吗?但是 nn.Linear(20, 64) 的返回值是多少?
根据文档, nn.Linear 定义为
class torch.nn.Linear(in_features: int, out_features: int, bias: bool = True)。
然而,在我的基本 OOP 教程中,我只看到类似class CLASSNAME(BASECLASS)CLASSNAME 类继承自 BASECLASS 的内容。该文档将所有这些内容写在括号之间的方式意味着什么?
另外,这条线x=fc1(x)不知何故让它看起来好像 fc1 现在是一个函数。
我在这里似乎缺乏 OOP 知识...任何帮助表示赞赏!