小编Jan*_*ane的帖子

如何在python中读取.tflite模型层的参数

我试图读取 tflite 模型并提取所有层的参数。

我的步骤:

  1. 我通过运行生成了 flatbuffers 模型表示(请先构建 flatc):

flatc -python tensorflow/tensorflow/lite/schema/schema.fbs

结果是tflite/包含图层描述文件 ( *.py) 和一些实用文件的文件夹。

  1. 我成功加载模型:

在导入错误的情况下:将 PYTHONPATH 设置为指向 tflite/ 所在的文件夹

from tflite.Model import Model
def read_tflite_model(file):
    buf = open(file, "rb").read()
    buf = bytearray(buf)
    model = Model.GetRootAsModel(buf, 0)
    return model
Run Code Online (Sandbox Code Playgroud)
  1. 我部分拉出模型和节点参数,并在节点上进行迭代:

模型部分:

def print_model_info(model):
        version = model.Version()
        print("Model version:", version)
        description = model.Description().decode('utf-8')
        print("Description:", description)
        subgraph_len = model.SubgraphsLength()
        print("Subgraph length:", subgraph_len)
Run Code Online (Sandbox Code Playgroud)

节点部分:

def print_nodes_info(model):
    # what does this 0 mean? should it always be zero?
    subgraph = …
Run Code Online (Sandbox Code Playgroud)

machine-learning deep-learning tensorflow tensorflow-lite

6
推荐指数
0
解决办法
1598
查看次数