怎么把.pb转换成TFLite格式?

Ayu*_*ena 4 tensorflow tensorflow-lite

我下载了在Azure认知服务中训练过的模型的retrained_graph.pbretrained_labels.txt文件。现在,我想使用该模型制作一个Android应用程序,为此,我必须将其转换为TFLite格式。我使用了toco,但出现以下错误:

ValueError: Invalid tensors 'input' were found.
Run Code Online (Sandbox Code Playgroud)

我基本上是在按照本教程操作,在第4步中遇到问题,并直接复制粘贴了终端代码:https : //heartbeat.fritz.ai/neural-networks-on-mobile-devices-with-tensorflow-lite-a-tutorial- 85b41f53230c

Aji*_*kya 6

我在这里making测,也许您输入了input_arrays=input。可能不正确。使用此脚本查找冻结推理图的输入和输出数组的名称

import tensorflow as tf
gf = tf.GraphDef()   
m_file = open('frozen_inference_graph.pb','rb')
gf.ParseFromString(m_file.read())

with open('somefile.txt', 'a') as the_file:
    for n in gf.node:
        the_file.write(n.name+'\n')

file = open('somefile.txt','r')
data = file.readlines()
print "output name = "
print data[len(data)-1]

print "Input name = "
file.seek ( 0 )
print file.readline()
Run Code Online (Sandbox Code Playgroud)

就我而言,它们是:

output name: SemanticPredictions
input name: ImageTensor
Run Code Online (Sandbox Code Playgroud)


小智 5

您可以使用实用程序tflite_convert,它是tensorflow 1.10(或更高版本)软件包的一部分。

浮点推理的简单用法是这样的:

tflite_convert \
    --output_file=/tmp/retrained_graph.tflite \
    --graph_def_file=/tmp/retrained_graph.pb \
    --input_arrays=input \
    --output_arrays=output
Run Code Online (Sandbox Code Playgroud)

输入和输出-是张量流图的输入和输出张量


小智 5

import tensorflow as tf
gf = tf.GraphDef()
m_file = open('frozen_inference_graph.pb','rb')
for n in gf.node:
    print( n.name )
Run Code Online (Sandbox Code Playgroud)

第一个是 input_arrays 姓氏是 output_arrays(可能不止一个,取决于您模型的输出数量)

我的输出

  • image_tensor <--- input_array
  • 投掷
  • 预处理器/地图/形状预处理器/地图/strided_slice/stack
  • 预处理器/地图/strided_slice/stack_1
  • .
  • .
  • .
  • 后处理器/BatchMultiClassNonMaxSuppression/map/
  • TensorArrayStack_5/TensorArrayGatherV3
  • 后处理器/Cast_3
  • 后处理器/挤压
  • 添加/年
  • 添加
  • detection_boxes <---output_array
  • detection_scores <---output_array
  • detection_multiclass_scores
  • detection_classes <---output_array
  • num_detections <---output_array
  • raw_detection_boxes
  • raw_detection_scores