如何在不重新定义模型的情况下在 PyTorch 中加载模型?

김수호*_*김수호 6 load model save pytorch

我正在寻找一种方法来保存 pytorch 模型,并在没有模型定义的情况下加载它。我的意思是我想保存我的模型,包括模型定义。

例如,我想要两个脚本。第一个将定义、训练和保存模型。第二个将加载和预测模型而不包括模型定义。

使用的方法torch.save(), torch.load()要求我在预测脚本中包含模型定义,但我想找到一种方法来加载模型而不在脚本中重新定义它。

jod*_*dag 7

您可以尝试使用跟踪将模型导出到TorchScript。这有局限性。由于 PyTorch 动态构建模型计算图的方式,如果您的模型中有任何控制流,那么导出的模型可能无法完全代表您的 Python 模块。TorchScript 仅在 PyTorch >= 1.0.0 中受支持,但我建议尽可能使用最新版本。

例如,没有任何条件行为的模型就可以

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc = nn.Linear(20 * 4 * 4, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        x = self.fc(x.flatten(1))
        return x
Run Code Online (Sandbox Code Playgroud)

我们可以将其导出如下

from torch import jit

net = Model()
# ... train your model

# put model in the mode you want to export (see bolded comment below)
net.eval()

# print example output
x = torch.ones(1, 3, 16, 16)
print(net(x))

# create TorchScript by tracing the computation graph with an example input
x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)
jit.save(net_trace, 'model.zip')
Run Code Online (Sandbox Code Playgroud)

如果成功,那么我们可以将我们的模型加载到一个新的 python 脚本中,而无需使用Model.

from torch import jit
net = jit.load('model.zip')

# print example output (should be same as during save)
x = torch.ones(1, 3, 16, 16)
print(net(x))
Run Code Online (Sandbox Code Playgroud)

加载的模型也是可训练的,但是,加载的模型只会在以. 例如,在这种情况下,我们以eval()mode导出我们的模型,因此net.train()在加载的模块上使用将不起作用。


控制流

像这样的模型,其行为在传递之间发生变化将无法正确导出。只有在此期间评估的代码jit.trace才会被导出。

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fca = nn.Linear(20 * 4 * 4, 2)
        self.fcb = nn.Linear(20 * 4 * 4, 2)

        self.use_a = True

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        if self.use_a:
            x = self.fca(x.flatten(1))
        else:
            x = self.fcb(x.flatten(1))
        return x
Run Code Online (Sandbox Code Playgroud)

我们仍然可以导出模型如下

import torch
from torch import jit

net = Model()
# ... train your model

net.eval()

# print example input
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))

# save model
x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)
jit.save(net_trace, "model.ts")
Run Code Online (Sandbox Code Playgroud)

在这种情况下,示例输出是

a: tensor([[-0.0959,  0.0657]], grad_fn=<AddmmBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<AddmmBackward>)
Run Code Online (Sandbox Code Playgroud)

然而,加载

import torch
from torch import jit

net = jit.load("model.ts")

# will not match the output from before
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))
Run Code Online (Sandbox Code Playgroud)

结果是

a: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
Run Code Online (Sandbox Code Playgroud)

请注意,分支“a”的逻辑不存在,因为它net.use_aFalse何时jit.trace被调用的。


脚本编写

这些限制可以克服,但需要您付出一些努力。您可以使用脚本功能来确保导出所有逻辑。