Lei*_*Bai 4 python ordereddictionary pytorch
我想保存最好的模型,然后在测试期间加载它。所以我使用了以下方法:
def train():
#training steps …
if acc > best_acc:
best_state = model.state_dict()
best_acc = acc
return best_state
Run Code Online (Sandbox Code Playgroud)
然后,在我使用的主函数中:
model.load_state_dict(best_state)
Run Code Online (Sandbox Code Playgroud)
恢复模型。
但是,我发现 best_state 总是与训练时的最后一个状态相同,而不是最佳状态。有谁知道原因以及如何避免它?
顺便说一下,我知道我可以使用torch.save(the_model.state_dict(), PATH)然后通过
the_model.load_state_dict(torch.load(PATH)). 但是,我不想将参数保存到文件中,因为训练和测试函数在一个文件中。
model.state_dict() 是 OrderedDict
from collections import OrderedDict
Run Code Online (Sandbox Code Playgroud)
您可以使用:
from copy import deepcopy
Run Code Online (Sandbox Code Playgroud)
解决问题
反而:
best_state = model.state_dict()
Run Code Online (Sandbox Code Playgroud)
你应该使用:
best_state = copy.deepcopy(model.state_dict())
Run Code Online (Sandbox Code Playgroud)
深(非浅)复制使可变 OrderedDict 实例不会best_state随着它发生变化。
您可以查看我关于在 PyTorch 中保存状态字典的其他答案。
| 归档时间: |
|
| 查看次数: |
864 次 |
| 最近记录: |