hau*_*to3 5 python serialization pytorch torchvision
I get an error while preparing the list of operators for a serialized TorchScript model. What is going wrong here? Does it already fail at load time?
Python code
# Dump list of operators used by MobileNetV2:
import torch, yaml
root = '/content/drive/My Drive/Monitoring/'
model = torch.jit.load(root+'model.pt')
ops = torch.jit.export_opnames(model)
with open('MobileNetV2.yaml', 'w') as output:
    yaml.dump(ops, output)
Stack trace
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-39-8b61c35fb898> in <module>()
3
4 root = '/content/drive/My Drive/Monitoring/'
----> 5 model = torch.jit.load(root+'model.pt')
6 ops = torch.jit.export_opnames(model)
7 with open('MobileNetV2.yaml', 'w') as output:
/usr/local/lib/python3.6/dist-packages/torch/jit/_serialization.py in load(f, map_location, _extra_files)
159 cu = torch._C.CompilationUnit()
160 if isinstance(f, str) or isinstance(f, pathlib.Path):
--> 161 cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
162 else:
163 cpp_module = torch._C.import_ir_module_from_buffer(
RuntimeError: [enforce fail at inline_container.cc:222] . file not found: archive/constants.pkl
Have you tried loading the model directly, without going through jit?
torch.load(model_path, map_location=device)
If you want to load the model with torch.jit.load, it must have been saved as a TorchScript module, for example one produced by torch.jit.trace:
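import torch

model = ...  # your nn.Module (left elided in the original answer)
model.eval()  # switch to inference mode before tracing
input_tensor = torch.rand(1, 3, 224, 224)  # example input matching the model's expected shape
script_model = torch.jit.trace(model, input_tensor)  # record the ops into a TorchScript module
script_model.save("model.jit.pt")  # this file can be read back with torch.jit.load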
You may want to check this thread.
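The error message in the question points to a `constants.pkl` entry missing from the archive, which suggests the file was written by `torch.save` rather than exported as TorchScript. One way to check before choosing a loader is to look inside the file, as in the rough sketch below. It assumes the zip-based layout used by recent PyTorch releases (TorchScript archives carry a `constants.pkl` entry, plain `torch.save` archives only a `data.pkl`). `describe_checkpoint` is a hypothetical helper written for illustration, not part of PyTorch.

```python
import zipfile


def describe_checkpoint(path):
    """Guess how a PyTorch file was saved by inspecting its zip entries.

    Hypothetical helper: the entry names checked here are an assumption
    about the archive layout, not a documented PyTorch contract.
    """
    if not zipfile.is_zipfile(path):
        # Legacy (non-zip) torch.save files and unrelated files end up here.
        return "not a zip archive"
    with zipfile.ZipFile(path) as archive:
        names = archive.namelist()
    # torch.jit.load looks for "<top-level dir>/constants.pkl",
    # which is the entry named in the error from the question.
    if any(name.endswith("/constants.pkl") for name in names):
        return "torchscript"
    # A zip-format torch.save file has "<top-level dir>/data.pkl" but no constants.pkl.
    if any(name.endswith("/data.pkl") for name in names):
        return "torch.save"
    return "unknown zip archive"
```

If the helper reports "torch.save" for the file, torch.load is probably the matching loader, and the model would need to be traced or scripted and saved again before torch.jit.load and torch.jit.export_opnames can work on it.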