PyTorch:state_dict和parameters()有什么区别?

Gul*_*zar 4 python machine-learning deep-learning pytorch

为了在pytorch中访问模型的参数,我看到了两种方法:

使用state_dict使用parameters()

我不知道有什么区别,或者一个是好的做法,另一个是不好的做法。

谢谢

kHa*_*hit 6

parameters()只给出参数,即重量和偏见的模块。

返回模块参数上的迭代器。

您可以按如下方式检查参数列表:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
Run Code Online (Sandbox Code Playgroud)

另一方面,state_dict 返回包含模块整个状态的字典。检查其source code不仅包含对的调用,parameters还包含buffers,等等。

包括参数和持久性缓冲区(例如运行平均值)。键是相应的参数和缓冲区名称。

state_dict使用以下命令检查包含的所有键:

model.state_dict().keys()
Run Code Online (Sandbox Code Playgroud)

例如,在中state_dict,您会找到bn1.running_meanrunning_var中没有的条目,例如和.parameters()


如果您只想访问参数,则可以简单地使用.parameters(),而出于传输学习中保存和加载模型的目的,您不仅需要保存state_dict参数。