如何在python中读取.tflite模型层的参数

Jan*_*ane 6 machine-learning deep-learning tensorflow tensorflow-lite

我试图读取 tflite 模型并提取所有层的参数。

我的步骤:

  1. 我通过运行生成了 flatbuffers 模型表示(请先构建 flatc):

flatc -python tensorflow/tensorflow/lite/schema/schema.fbs

结果是tflite/包含图层描述文件 ( *.py) 和一些实用文件的文件夹。

  1. 我成功加载模型:

在导入错误的情况下:将 PYTHONPATH 设置为指向 tflite/ 所在的文件夹

from tflite.Model import Model
def read_tflite_model(file):
    buf = open(file, "rb").read()
    buf = bytearray(buf)
    model = Model.GetRootAsModel(buf, 0)
    return model
Run Code Online (Sandbox Code Playgroud)
  1. 我部分拉出模型和节点参数,并在节点上进行迭代:

模型部分:

def print_model_info(model):
        version = model.Version()
        print("Model version:", version)
        description = model.Description().decode('utf-8')
        print("Description:", description)
        subgraph_len = model.SubgraphsLength()
        print("Subgraph length:", subgraph_len)
Run Code Online (Sandbox Code Playgroud)

节点部分:

def print_nodes_info(model):
    # what does this 0 mean? should it always be zero?
    subgraph = model.Subgraphs(0)
    operators_len = subgraph.OperatorsLength()
    print('Operators length:', operators_len)

    from collections import deque
    nodes = deque(subgraph.InputsAsNumpy())

    STEP_N = 0
    MAX_STEPS = operators_len
    print("Nodes info:")
    while len(nodes) != 0 and STEP_N <= MAX_STEPS:
        print("MAX_STEPS={} STEP_N={}".format(MAX_STEPS, STEP_N))
        print("-" * 60)

        node_id = nodes.pop()
        print("Node id:", node_id)

        tensor = subgraph.Tensors(node_id)
        print("Node name:", tensor.Name().decode('utf-8'))
        print("Node shape:", tensor.ShapeAsNumpy())

        # which type is it? what does it mean?
        type_of_tensor = tensor.Type()
        print("Tensor type:", type_of_tensor)

        quantization = tensor.Quantization()
        min = quantization.MinAsNumpy()
        max = quantization.MaxAsNumpy()
        scale = quantization.ScaleAsNumpy()
        zero_point = quantization.ZeroPointAsNumpy()
        print("Quantization: ({}, {}), s={}, z={}".format(min, max, scale, zero_point))

        # I do not understand it again. what is j, that I set to 0 here?
        operator = subgraph.Operators(0)
        for i in operator.OutputsAsNumpy():
            nodes.appendleft(i)

        STEP_N += 1

    print("-"*60)
Run Code Online (Sandbox Code Playgroud)

请指向我使用此 API 的文档或一些示例。

我的问题是:

  1. 我无法获得有关此 API 的文档

  2. 迭代 Tensor 对象对我来说似乎是不可能的,因为它没有 Inputs 和 Outputs 方法。+subgraph.Operators(j=0)我不明白 j 在这里是什么意思。因此,我的循环会经过两个节点:输入(一次)和下一个一遍又一遍。

  3. 迭代 Operator 对象肯定是可能的:

在这里,我们对它们进行了迭代,但我不知道如何映射 Operator 和 Tensor。

def print_in_out_info_of_all_operators(model):
    # what does this 0 mean? should it always be zero?
    subgraph = model.Subgraphs(0)
    for i in range(subgraph.OperatorsLength()):
        operator = subgraph.Operators(i)
        print('Outputs', operator.OutputsAsNumpy())
        print('Inputs', operator.InputsAsNumpy())
Run Code Online (Sandbox Code Playgroud)
  1. 我不明白如何从 Operator 对象中提取参数。BuiltinOptions 方法给了我 Table 对象,我不知道要映射什么。

  2. subgraph = model.Subgraphs(0) 这个0是什么意思?它应该始终为零吗?显然不是,但它是什么?子图的ID?如果是这样 - 我很高兴。如果不是,请尝试解释。