这应该是一个快速的.当我在PyTorch中使用预定义模块时,我通常可以非常轻松地访问其权重.但是,如果我先将模块包装在nn.Sequential()中,如何访问它们?请看下面的玩具示例
class My_Model_1(nn.Module):
def __init__(self,D_in,D_out):
super(My_Model_1, self).__init__()
self.layer = nn.Linear(D_in,D_out)
def forward(self,x):
out = self.layer(x)
return out
class My_Model_2(nn.Module):
def __init__(self,D_in,D_out):
super(My_Model_2, self).__init__()
self.layer = nn.Sequential(nn.Linear(D_in,D_out))
def forward(self,x):
out = self.layer(x)
return out
model_1 = My_Model_1(10,10)
print(model_1.layer.weight)
model_2 = My_Model_2(10,10)
# How do I print the weights now?
# model_2.layer.0.weight doesn't work.
Run Code Online (Sandbox Code Playgroud)
轻松获得重量的权重就是使用state_dict()
您的模型.
这适用于您的情况:
for k, v in model_2.state_dict().iteritems():
print("Layer {}".format(k))
print(v)
Run Code Online (Sandbox Code Playgroud)
另一个选择是获取modules()
迭代器.如果您事先知道图层的类型,这也应该有效:
for layer in model_2.modules():
if isinstance(layer, nn.Linear):
print(layer.weight)
Run Code Online (Sandbox Code Playgroud)
从PyTorch论坛,这是推荐的方式:
model_2.layer[0].weight
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
7615 次 |
最近记录: |