我试图读取 tflite 模型并提取所有层的参数。
我的步骤:
flatc -python tensorflow/tensorflow/lite/schema/schema.fbs
结果是tflite/包含图层描述文件 ( *.py) 和一些实用文件的文件夹。
在导入错误的情况下:将 PYTHONPATH 设置为指向 tflite/ 所在的文件夹
from tflite.Model import Model
def read_tflite_model(file):
buf = open(file, "rb").read()
buf = bytearray(buf)
model = Model.GetRootAsModel(buf, 0)
return model
Run Code Online (Sandbox Code Playgroud)
模型部分:
def print_model_info(model):
version = model.Version()
print("Model version:", version)
description = model.Description().decode('utf-8')
print("Description:", description)
subgraph_len = model.SubgraphsLength()
print("Subgraph length:", subgraph_len)
Run Code Online (Sandbox Code Playgroud)
节点部分:
def print_nodes_info(model):
# what does this 0 mean? should it always be zero?
subgraph = …Run Code Online (Sandbox Code Playgroud)