小编rel*_*100的帖子

Pytorch 的 nn.Linear(x,y) 返回什么?

我是面向对象的新手,并且在理解以下内容时遇到困难:

import torch.nn as nn

class mynet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 64)
    
    def forward(self, x):
        x = self.fc1(x)

Run Code Online (Sandbox Code Playgroud)

该行self.fc1 = nn.Linear(20, 64) 应该为我的类创建一个成员变量 fc1,对吗?但是 nn.Linear(20, 64) 的返回值是多少?

根据文档, nn.Linear 定义为 class torch.nn.Linear(in_features: int, out_features: int, bias: bool = True)

然而,在我的基本 OOP 教程中,我只看到类似class CLASSNAME(BASECLASS)CLASSNAME 类继承自 BASECLASS 的内容。该文档将所有这些内容写在括号之间的方式意味着什么?

另外,这条线x=fc1(x)不知何故让它看起来好像 fc1 现在是一个函数。

我在这里似乎缺乏 OOP 知识...任何帮助表示赞赏!

oop class pytorch

1
推荐指数
1
解决办法
4788
查看次数

标签 统计

class ×1

oop ×1

pytorch ×1