如何从ONNX模型获取输入数据?

Sha*_*ang 5 python pytorch onnx

我已将 PyTorch 模型导出到 ONNX。现在,我有办法从 ONNX 模型获取输入层吗?

将 PyTorch 模型导出到 ONNX

import torch.onnx
checkpoint = torch.load("./saved_pytorch_model.pth")
model.load_state_dict(checkpoint['state_dict'])
input = torch.tensor(df_X.values).float()
torch.onnx.export(model, input, "onnx_model.onnx")
Run Code Online (Sandbox Code Playgroud)

加载 ONNX 模型

onnx_model = onnx.load('onnx_model.onnx')
Run Code Online (Sandbox Code Playgroud)

我希望能够以某种方式从 onnx_model 获取输入层。这可能吗?

Aci*_*urn 5

Onnx 库提供了 API 来提取所有输入的名称和形状,如下所示:

model = onnx.load(onnx_model)
inputs = {}
for inp in model.graph.input:
    shape = str(inp.type.tensor_type.shape.dim)
    inputs[inp.name] = [int(s) for s in shape.split() if s.isdigit()]
Run Code Online (Sandbox Code Playgroud)


小智 3

ONNX 模型是一个 protobuf 结构,如此处定义 ( https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto )。您可以使用为 python 生成的标准 protobuf 方法来使用它(请参阅: https: //developers.google.com/protocol-buffers/docs/reference/python- generated )。我不明白你到底想提取什么。但您可以迭代组成图的节点 ( model.graph.node )。图中的第一个节点可能对应也可能不对应于您可能认为的第一层(这取决于转换的完成方式)。您还可以获取模型的输入(model.graph.input)。