小编Mic*_*key的帖子

PyTorch:如何正确创建nn.Linear()的列表

我创建了一个具有nn.Module作为子类的类。

在我的课堂上,我必须创建N个线性变换,其中N作为类参数给出。

因此,我进行如下操作:

    self.list_1 = []

    for i in range(N):
        self.list_1.append(nn.Linear(self.x, 1, bias=mlp_bias))
Run Code Online (Sandbox Code Playgroud)

在forward方法中,我调用这些矩阵(具有list_1 [i])并合并结果。

两件事情 :

1)

即使我使用model.cuda(),这些线性变换仍在cpu上使用,但出现以下错误:

RuntimeError:类型为Variable [torch.cuda.FloatTensor]的预期对象,但为参数#1'mat2'找到类型Variable [torch.FloatTensor]

我要做

self.list_1.append(nn.Linear(self.x, 1, bias=mlp_bias).cuda())
Run Code Online (Sandbox Code Playgroud)

如果不是,这不是必需的,我这样做:

self.nn = nn.Linear(self.x, 1, bias=mlp_bias)
Run Code Online (Sandbox Code Playgroud)

然后直接使用self.nn。

2)

出于更明显的原因,当我在主菜单中打印(模型)时,列表中的线性矩阵也不会打印。

还有其他方法吗?也许使用bmm吗?我觉得不太容易,实际上我想分别得到N个结果。

先感谢您,

中号

python pytorch

4
推荐指数
1
解决办法
1635
查看次数

标签 统计

python ×1

pytorch ×1