在pytorch中重置神经网络的参数

lea*_*ner 7 neural-network python-3.x gated-recurrent-unit pytorch

我有一个具有以下结构的神经网络:

class myNetwork(nn.Module):
    def __init__(self):
        super(myNetwork, self).__init__()
        self.bigru = nn.GRU(input_size=2, hidden_size=100, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(200, 32)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.fc2 = nn.Linear(32, 2)
        torch.nn.init.xavier_uniform_(self.fc2.weight)
Run Code Online (Sandbox Code Playgroud)

我需要通过重置神经网络的参数来将模型恢复到未学习的状态。nn.Linear我可以使用以下方法对图层执行此操作:

def reset_weights(self):
    torch.nn.init.xavier_uniform_(self.fc1.weight)
    torch.nn.init.xavier_uniform_(self.fc2.weight)
Run Code Online (Sandbox Code Playgroud)

但是,要重置图层的权重nn.GRU,我找不到任何此类片段。

我的问题是如何重置图层nn.GRU?重置网络的任何其他方法也可以。任何帮助表示赞赏。

Dis*_*ani 12

reset_parameters您可以在图层上使用方法。正如这里给出的

for layer in model.children():
   if hasattr(layer, 'reset_parameters'):
       layer.reset_parameters()
Run Code Online (Sandbox Code Playgroud)

或者另一种方法是先保存模型,然后重新加载模块状态。使用torch.savetorch.load 查看文档了解更多信息保存和加载模型