Pytorch 的 nn.Linear(x,y) 返回什么?

rel*_*100 1 oop class pytorch

我是面向对象的新手,并且在理解以下内容时遇到困难:

import torch.nn as nn

class mynet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 64)
    
    def forward(self, x):
        x = self.fc1(x)

Run Code Online (Sandbox Code Playgroud)

该行self.fc1 = nn.Linear(20, 64) 应该为我的类创建一个成员变量 fc1,对吗?但是 nn.Linear(20, 64) 的返回值是多少?

根据文档, nn.Linear 定义为 class torch.nn.Linear(in_features: int, out_features: int, bias: bool = True)

然而,在我的基本 OOP 教程中,我只看到类似class CLASSNAME(BASECLASS)CLASSNAME 类继承自 BASECLASS 的内容。该文档将所有这些内容写在括号之间的方式意味着什么?

另外,这条线x=fc1(x)不知何故让它看起来好像 fc1 现在是一个函数。

我在这里似乎缺乏 OOP 知识...任何帮助表示赞赏!

jod*_*dag 5

首先让我们看一下这个

self.fc1 = nn.Linear(20, 64)
Run Code Online (Sandbox Code Playgroud)

对于对 Python 和 OOP 有基本了解的人来说,这部分可能很熟悉。在这里,我们只是创建一个类的新实例,并使用位置参数和分别对应于和 来nn.Linear初始化该类。文档中的参数是要传递给 的方法的预期参数。2064in_featuresout_featuresnn.Linear__init__

现在来说说可能有点令人困惑的部分

x = self.fc1(x)
Run Code Online (Sandbox Code Playgroud)

该类nn.Linear可调用的,因为它的父类nn.Module实现了一个名为 的特殊方法__call__self.fc1这意味着您可以像函数一样对待并执行类似的操作x = self.fc1(x),这相当于x = self.fc1.__call__(x).