Sha*_*ang 5 python pytorch onnx
我已将 PyTorch 模型导出到 ONNX。现在,我有办法从 ONNX 模型获取输入层吗?
将 PyTorch 模型导出到 ONNX
import torch.onnx
checkpoint = torch.load("./saved_pytorch_model.pth")
model.load_state_dict(checkpoint['state_dict'])
input = torch.tensor(df_X.values).float()
torch.onnx.export(model, input, "onnx_model.onnx")
Run Code Online (Sandbox Code Playgroud)
加载 ONNX 模型
onnx_model = onnx.load('onnx_model.onnx')
Run Code Online (Sandbox Code Playgroud)
我希望能够以某种方式从 onnx_model 获取输入层。这可能吗?
Onnx 库提供了 API 来提取所有输入的名称和形状,如下所示:
model = onnx.load(onnx_model)
inputs = {}
for inp in model.graph.input:
shape = str(inp.type.tensor_type.shape.dim)
inputs[inp.name] = [int(s) for s in shape.split() if s.isdigit()]
Run Code Online (Sandbox Code Playgroud)
小智 3
ONNX 模型是一个 protobuf 结构,如此处定义 ( https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto )。您可以使用为 python 生成的标准 protobuf 方法来使用它(请参阅: https: //developers.google.com/protocol-buffers/docs/reference/python- generated )。我不明白你到底想提取什么。但您可以迭代组成图的节点 ( model.graph.node )。图中的第一个节点可能对应也可能不对应于您可能认为的第一层(这取决于转换的完成方式)。您还可以获取模型的输入(model.graph.input)。
| 归档时间: |
|
| 查看次数: |
9118 次 |
| 最近记录: |