best_state 在 pytorch 训练期间随模型变化

Lei*_*Bai 4 python ordereddictionary pytorch

我想保存最好的模型,然后在测试期间加载它。所以我使用了以下方法:

def train():  
    #training steps …  
    if acc > best_acc:  
        best_state = model.state_dict()  
        best_acc = acc
    return best_state  
Run Code Online (Sandbox Code Playgroud)

然后,在我使用的主函数中:

model.load_state_dict(best_state)  
Run Code Online (Sandbox Code Playgroud)

恢复模型。

但是,我发现 best_state 总是与训练时的最后一个状态相同,而不是最佳状态。有谁知道原因以及如何避免它?

顺便说一下,我知道我可以使用torch.save(the_model.state_dict(), PATH)然后通过 the_model.load_state_dict(torch.load(PATH)). 但是,我不想将参数保存到文件中,因为训练和测试函数在一个文件中。

pro*_*sti 5

model.state_dict()OrderedDict

from collections import OrderedDict
Run Code Online (Sandbox Code Playgroud)

您可以使用:

from copy import deepcopy
Run Code Online (Sandbox Code Playgroud)

解决问题

反而:

best_state = model.state_dict() 
Run Code Online (Sandbox Code Playgroud)

你应该使用:

best_state = copy.deepcopy(model.state_dict())
Run Code Online (Sandbox Code Playgroud)

深(非浅)复制使可变 OrderedDict 实例不会best_state随着它发生变化。

您可以查看我关于在 PyTorch 中保存状态字典的其他答案