如何在 PyTorch 中保存模型架构?

M.Z*_*.Z. 5 pytorch

我知道我可以通过torch.save(model.state_dict(), FILE)或保存模型torch.save(model, FILE)。但是它们都没有保存模型的架构。

那么我们如何在 PyTorch 中保存模型的架构,就像在 Tensorflow 中创建.pb文件一样?我想对我的模型应用不同的调整。如果我不能保存模型的架构,我有什么比每次复制整个类定义并创建一个新类更好的方法吗?

Ros*_*osh 6

您可以参考这篇文章来了解如何保存分类器。要对模型进行调整,您可以创建一个新模型,该模型是现有模型的子模型。


class newModel( oldModelClass):
    def __init__(self):
        super(newModel, self).__init__()

Run Code Online (Sandbox Code Playgroud)

通过这种设置, newModel 具有 的所有层以及前向功能oldModelClass。如果需要进行调整,可以在函数中定义新层__init__,然后编写新的前向函数来定义它。


Sha*_*hai 6

保存所有参数 ( state_dict) 和所有模块是不够的,因为存在操纵张量的操作,但仅反映在特定实现的实际代码中(例如, reshapeResNet 中的 ing)。

此外,网络可能没有固定且预先确定的计算图:您可以认为具有分支或循环(循环)的网络。

因此,您必须保存实际的代码。

或者,如果网络中没有分支/循环,您可以保存计算图,例如,参见这篇文章

onnx您还应该考虑使用并具有捕获训练权重和计算图的表示形式来导出模型。