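You can try exporting your model to TorchScript using tracing. This has limitations. Because PyTorch builds the model's computation graph dynamically, if your model contains any control flow then the exported model may not fully represent your Python module. TorchScript is only supported in PyTorch >= 1.0.0, though I would recommend using the latest version possible.
For example, a model without any conditional behavior is fine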
import torch.nn.functional as F
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc = nn.Linear(20 * 4 * 4, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        x = self.fc(x.flatten(1))
        return x
We can export this as follows
import torch
from torch import jit

net = Model()
# ... train your model

# put model in the mode you want to export (see the note on modes below)
net.eval()

# print example output
x = torch.ones(1, 3, 16, 16)
print(net(x))

# create TorchScript by tracing the computation graph with an example input
x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)
jit.save(net_trace, 'model.zip')
If successful, we can load our model into a new Python script without needing the Model class.
import torch
from torch import jit

net = jit.load('model.zip')

# print example output (should be same as during save)
x = torch.ones(1, 3, 16, 16)
print(net(x))
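The loaded model is also trainable. However, the loaded model will only behave in the mode it was exported in. For example, in this case we exported our model in eval() mode, so calling net.train() on the loaded module will have no effect.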
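As a rough illustration of the point about modes, here is a minimal sketch using a toy nn.Dropout module (a hypothetical stand-in, not the model above), on the assumption that dropout makes the train/eval difference easy to see. It is traced in eval() mode and then switched to train().

```python
import torch
from torch import jit, nn

# Hypothetical toy module: dropout behaves differently in train and eval mode.
drop = nn.Dropout(p=0.5)
drop.eval()  # export in eval mode, where dropout is the identity

x = torch.ones(1, 1000)
drop_trace = jit.trace(drop, x)

# Switching the traced module to train mode does not bring dropout back:
# the trace recorded the eval-mode behavior.
drop_trace.train()
still_identity = torch.equal(drop_trace(x), x)

# The original Python module does change behavior in train mode.
drop.train()
eager_changed = not torch.equal(drop(x), x)

print("traced module still acts like eval():", still_identity)
print("eager module changed in train():", eager_changed)
```

If you need both behaviors after export, one option is to trace and save the model once per mode.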
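A model like this one, whose behavior changes between passes, will not be exported correctly. Only the code that is evaluated during jit.trace gets exported.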
import torch.nn.functional as F
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fca = nn.Linear(20 * 4 * 4, 2)
        self.fcb = nn.Linear(20 * 4 * 4, 2)
        self.use_a = True

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        # plain Python branch: only the path taken during tracing is recorded
        if self.use_a:
            x = self.fca(x.flatten(1))
        else:
            x = self.fcb(x.flatten(1))
        return x
We can still export the model as follows
import torch
from torch import jit

net = Model()
# ... train your model
net.eval()

# print example outputs for both branches
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))

# save model (net.use_a is still False here, so only branch "b" is traced)
x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)
jit.save(net_trace, "model.ts")
In this case the example outputs are
a: tensor([[-0.0959, 0.0657]], grad_fn=<AddmmBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<AddmmBackward>)
However, loading
import torch
from torch import jit
net = jit.load("model.ts")
# will not match the output from before
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))
results in
a: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
Notice that the logic for branch "a" is not present, because net.use_a was False when jit.trace was called.
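The same thing happens when the branch depends on the input data rather than on a module attribute. The following is a minimal sketch with a hypothetical helper function, double_if_positive, which is not part of the model above. It is traced with a positive input and then called with a negative one.

```python
import warnings

import torch
from torch import jit

def double_if_positive(x):
    # Data-dependent branch: which path runs depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return x - 1

with warnings.catch_warnings():
    # Tracing emits a TracerWarning about exactly this problem;
    # it is silenced here only to keep the output short.
    warnings.simplefilter("ignore")
    traced_fn = jit.trace(double_if_positive, torch.ones(3))

neg = torch.tensor([-3.0, -3.0, -3.0])
print("eager :", double_if_positive(neg))  # takes the `x - 1` branch
print("traced:", traced_fn(neg))           # replays the recorded `x * 2` branch
```

In practice it is worth leaving the TracerWarning visible, since it is usually the first hint that a trace has dropped a branch.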
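These limitations can be overcome, but it takes some effort on your end. You can use the scripting functionality (torch.jit.script) to ensure that all of the logic is exported.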
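As a hedged sketch of what scripting might look like, the example below uses a stripped-down, hypothetical TwoHeads module rather than the full model, and passes the branch flag as a typed forward argument (an assumption made here to keep the example small). jit.script compiles both branches, so the choice survives a save/load round trip.

```python
import io

import torch
from torch import jit, nn

class TwoHeads(nn.Module):
    # Hypothetical, stripped-down stand-in for the two-branch model.
    def __init__(self):
        super().__init__()
        self.fca = nn.Linear(8, 2)
        self.fcb = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor, use_a: bool) -> torch.Tensor:
        # jit.script compiles both branches instead of recording one of them.
        if use_a:
            out = self.fca(x)
        else:
            out = self.fcb(x)
        return out

net = TwoHeads().eval()
net_script = jit.script(net)

# Round-trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
jit.save(net_script, buffer)
buffer.seek(0)
loaded = jit.load(buffer)

x = torch.ones(1, 8)
print("a matches:", torch.allclose(loaded(x, True), net(x, True)))
print("b matches:", torch.allclose(loaded(x, False), net(x, False)))
```

Scripting only accepts the subset of Python that TorchScript understands, so a real model may need type annotations or small rewrites before jit.script succeeds.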