在无法访问模型类代码的情况下保存 PyTorch 模型

Mic*_*l D 12 python deep-learning pytorch

如何在不需要在某处定义模型类的情况下保存 PyTorch 模型?


免责声明

PyTorch 中保存训练模型的最佳方式?没有保存的模型,而无需访问模型类代码解决方案(或工作液)。

nlm*_*lml 19

如果您打算使用可用的 Pytorch 库(即 Python、C++ 或它支持的其他平台中的 Pytorch)进行推理,那么最好的方法是通过TorchScript

我认为最简单的方法是使用trace = torch.jit.trace(model, typical_input)然后torch.jit.save(trace, path). 然后,您可以使用 加载跟踪模型torch.jit.load(path)

这是一个非常简单的例子。我们制作两个文件:

train.py

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))
    traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, "model.pth")
Run Code Online (Sandbox Code Playgroud)

infer.py

import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))
Run Code Online (Sandbox Code Playgroud)

按顺序运行这些会得到结果:

python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])
Run Code Online (Sandbox Code Playgroud)

结果是一样的,所以我们很好。(请注意,由于 nn.Linear 层初始化的随机性,这里每次的结果都会不同)。

TorchScript 提供了将更复杂的架构和图形定义(包括 if 语句、while 循环等)保存在单个文件中,而无需在推理时重新定义图形。有关更高级的可能性,请参阅文档(上面链接)。

  • 主要的缺点是你仍然需要某种 pytorch 环境。另外,如果你想继续训练痕迹,我想那将是非常困难/不可能的。有时它也可能有点错误/难以调试。但这基本上是 pytorch 对在张量流中轻松保存整个图的回答。它随着每个版本的发布而不断改进,并且在我看来已经非常好了。 (3认同)