pytorch nn.Sequential(*list) TypeError:列表不是模块子类

Ele*_*ogs 4 python pytorch

当我使用pytorch训练模型时,我尝试打印整个网络结构

所以我将所有图层打包在一个列表中然后使用nn.Sequential(*list)

但它不起作用,并且 TypeError: list 不是 Module 子类

小智 10

请提供您创建的图层列表,您确定您没有在其中犯任何错误。尝试检查您的列表是否实际上是 [] 而不是 [[..]]。我注意到的另一件事是你有list一个变量名,这不是一个好主意 -list是一个Python关键字。

我尝试编写一个解压列表的示例代码,它对我来说效果很好。

import torch
import torch.nn as nn                                                                           net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))                                           
layers = [nn.Linear(2, 2), nn.Linear(2, 2)]                                                  
net = nn.Sequential(*layers)
print(net)
Run Code Online (Sandbox Code Playgroud)

运行没有任何错误,结果是:

Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
Run Code Online (Sandbox Code Playgroud)

希望这可以帮助。:)