Jan*_*ane 6 machine-learning deep-learning tensorflow tensorflow-lite
我试图读取 tflite 模型并提取所有层的参数。
我的步骤:
flatc -python tensorflow/tensorflow/lite/schema/schema.fbs
结果是tflite/包含图层描述文件 ( *.py) 和一些实用文件的文件夹。
在导入错误的情况下:将 PYTHONPATH 设置为指向 tflite/ 所在的文件夹
from tflite.Model import Model
def read_tflite_model(file):
buf = open(file, "rb").read()
buf = bytearray(buf)
model = Model.GetRootAsModel(buf, 0)
return model
Run Code Online (Sandbox Code Playgroud)
模型部分:
def print_model_info(model):
version = model.Version()
print("Model version:", version)
description = model.Description().decode('utf-8')
print("Description:", description)
subgraph_len = model.SubgraphsLength()
print("Subgraph length:", subgraph_len)
Run Code Online (Sandbox Code Playgroud)
节点部分:
def print_nodes_info(model):
# what does this 0 mean? should it always be zero?
subgraph = model.Subgraphs(0)
operators_len = subgraph.OperatorsLength()
print('Operators length:', operators_len)
from collections import deque
nodes = deque(subgraph.InputsAsNumpy())
STEP_N = 0
MAX_STEPS = operators_len
print("Nodes info:")
while len(nodes) != 0 and STEP_N <= MAX_STEPS:
print("MAX_STEPS={} STEP_N={}".format(MAX_STEPS, STEP_N))
print("-" * 60)
node_id = nodes.pop()
print("Node id:", node_id)
tensor = subgraph.Tensors(node_id)
print("Node name:", tensor.Name().decode('utf-8'))
print("Node shape:", tensor.ShapeAsNumpy())
# which type is it? what does it mean?
type_of_tensor = tensor.Type()
print("Tensor type:", type_of_tensor)
quantization = tensor.Quantization()
min = quantization.MinAsNumpy()
max = quantization.MaxAsNumpy()
scale = quantization.ScaleAsNumpy()
zero_point = quantization.ZeroPointAsNumpy()
print("Quantization: ({}, {}), s={}, z={}".format(min, max, scale, zero_point))
# I do not understand it again. what is j, that I set to 0 here?
operator = subgraph.Operators(0)
for i in operator.OutputsAsNumpy():
nodes.appendleft(i)
STEP_N += 1
print("-"*60)
Run Code Online (Sandbox Code Playgroud)
请指向我使用此 API 的文档或一些示例。
我的问题是:
我无法获得有关此 API 的文档
迭代 Tensor 对象对我来说似乎是不可能的,因为它没有 Inputs 和 Outputs 方法。+subgraph.Operators(j=0)我不明白 j 在这里是什么意思。因此,我的循环会经过两个节点:输入(一次)和下一个一遍又一遍。
迭代 Operator 对象肯定是可能的:
在这里,我们对它们进行了迭代,但我不知道如何映射 Operator 和 Tensor。
def print_in_out_info_of_all_operators(model):
# what does this 0 mean? should it always be zero?
subgraph = model.Subgraphs(0)
for i in range(subgraph.OperatorsLength()):
operator = subgraph.Operators(i)
print('Outputs', operator.OutputsAsNumpy())
print('Inputs', operator.InputsAsNumpy())
Run Code Online (Sandbox Code Playgroud)
我不明白如何从 Operator 对象中提取参数。BuiltinOptions 方法给了我 Table 对象,我不知道要映射什么。
subgraph = model.Subgraphs(0)
这个0是什么意思?它应该始终为零吗?显然不是,但它是什么?子图的ID?如果是这样 - 我很高兴。如果不是,请尝试解释。
| 归档时间: |
|
| 查看次数: |
1598 次 |
| 最近记录: |