如何在tensorflow lite模型中设置input_array和outout_array名称

fla*_*hen 5 tensorflow tensorflow-lite

操作系统平台和发行版:Linux Ubuntu 14.04

TensorFlow版本:来自二进制的Tensorflow(1.4.0)

CUDA / cuDNN版本:CUDA 8.0

我已经使用tensorflow训练了一些自定义模型,并试图使其成为移动应用程序的tensorflow lite模型。

我的模型定义如下:

def P_Net(inputs,label=None,bbox_target=None,landmark_target=None,training=True):
    #define common param
    with slim.arg_scope([slim.conv2d],
                        activation_fn=prelu,
                        weights_initializer=slim.xavier_initializer(),
                        biases_initializer=tf.zeros_initializer(),
                        weights_regularizer=slim.l2_regularizer(0.0005), 
                        padding='valid'):
        print inputs.get_shape()
        net = slim.conv2d(inputs, 28, 3, stride=1,scope='conv1')
......
        conv4_1 = slim.conv2d(net,num_outputs=2,kernel_size=[1,1],stride=1,scope='conv4_1',activation_fn=tf.nn.softmax)
        #conv4_1 = slim.conv2d(net,num_outputs=1,kernel_size=[1,1],stride=1,scope='conv4_1',activation_fn=tf.nn.sigmoid)

        print conv4_1.get_shape()
        #batch*H*W*4
        bbox_pred = slim.conv2d(net,num_outputs=4,kernel_size=[1,1],stride=1,scope='conv4_2',activation_fn=None)
        print bbox_pred.get_shape()
Run Code Online (Sandbox Code Playgroud)

其中conv4_1和conv4_2是输出层。

我冻结模型:

freeze_graph.freeze_graph('out_put_model/model.pb', '', False, model_path, 'Squeeze,Squeeze_1', '', '', 'out_put_model/frozen_model.pb', '', '')
Run Code Online (Sandbox Code Playgroud)

之后,我可以使用张量板查看图形。并将其读回以进行仔细检查,然后从检查点将事物身份输出到模型。

然后我尝试将Frozen_model.pb保存到tensorflow lite模型。找到tensorflow 1.4.0没有tensorflow lite模块,我从github检出tensorflow,然后bazel运行toco像这样:

bazel run --config=opt   //tensorflow/contrib/lite/toco:toco --   --input_file='/home/sens/mtcnn_cat/MTCNN-Tensorflow/test/out_put_model/frozen_model.pb'    --output_file='/home/sens/mtcnn_cat/MTCNN-Tensorflow/test/out_put_model/pnet.tflite'    --inference_type=FLOAT   --input_shape=1,128,128,3   --input_array=image_height,image_width,input_image   --output_array=Squeeze,Squeeze_1  --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --dump_graphviz=/tmp
Run Code Online (Sandbox Code Playgroud)

但是,输出抱怨找不到输出数组:

INFO: Running command line: bazel-bin/tensorflow/contrib/lite/toco/toco '--input_file=/home/sens/mtcnn_cat/MTCNN-Tensorflow/test/out_put_model/frozen_model.pb' '--output_file=/home/sens/mtcnn_cat/MTCNN-Tensorflow/test/out_put_model/pnet.tflite' '--inference_type=FLOAT' '--input_shape=1,128,128,3' '--input_array=image_height,image_width,input_image' '--output_array=Squeeze,Squeeze_1' '--input_format=TENSORFLOW_GRAPHDEF' '--output_format=TFLITE' '--dump_graphviz=/tmp'
2018-04-03 11:17:37.412589: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1172] Converting unsupported operation: Abs
2018-04-03 11:17:37.412660: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1172] Converting unsupported operation: Abs
2018-04-03 11:17:37.412699: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1172] Converting unsupported operation: Abs
2018-04-03 11:17:37.412880: F tensorflow/contrib/lite/toco/tooling_util.cc:686] Check failed: model.HasArray(output_array) Output array not found: Squeeze,Squeeze_1
Run Code Online (Sandbox Code Playgroud)

问题:1.如何设置--output_array=Squeeze,Squeeze_1参数?我认为这与张freeze_graph()量板上的输出节点相同,我确实找到了“ Squeeze”和“ Squeeze_1”节点在此处输入图片说明

  1. 如何设置--input_shape=1,128,128,3 --input_array=image_height,image_width,input_image参数?我检查并发现移动设备确实具有固定大小的图像输入,但是在我的模型中,输入图像和完全卷积输入没有固定大小,例如:

        self.image_op = tf.placeholder(tf.float32, name='input_image')
        self.width_op = tf.placeholder(tf.int32, name='image_width')
        self.height_op = tf.placeholder(tf.int32, name='image_height')
        image_reshape = tf.reshape(self.image_op, [1, self.height_op, self.width_op, 3])
    
    Run Code Online (Sandbox Code Playgroud)

并重塑为1 *宽*高* 3 在此处输入图片说明

那么如何将其写为输入形状呢?

Mah*_*esh 2

多亏了tensorflow,将冻结模型转换为tf_lite从来都不是一件容易的事。希望这段代码可以帮助您总结图表并帮助您找到输出和输入数组

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph={PATH_TO_FROZEN_GRAPH}/optimized_best.pb`
Run Code Online (Sandbox Code Playgroud)