如何优化推理一个简单,保存的TensorFlow 1.0.1图表?

tdu*_*ube 19 python python-2.7 tensorflow tensorflow-gpu

我无法optimize_for_inference在一个简单的,保存的TensorFlow图(Python 2.7;安装包pip install tensorflow-gpu==1.0.1)上成功运行该模块.

背景

保存TensorFlow图

这是我的Python脚本,用于生成并保存一个简单的图形,以便为我的输入x placeholder操作添加5 .

import tensorflow as tf

# make and save a simple graph
G = tf.Graph()
with G.as_default():
    x = tf.placeholder(dtype=tf.float32, shape=(), name="x")
    a = tf.Variable(5.0, name="a")
    y = tf.add(a, x, name="y")
    saver = tf.train.Saver()

with tf.Session(graph=G) as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(fetches=[y], feed_dict={x: 1.0})
    print(out)
    saver.save(sess=sess, save_path="test_model")
Run Code Online (Sandbox Code Playgroud)

恢复TensorFlow图

我有一个简单的恢复脚本,可以重新创建已保存的图形并恢复图形参数.保存/恢复脚本都生成相同的输出.

import tensorflow as tf

# Restore simple graph and test model output
G = tf.Graph()

with tf.Session(graph=G) as sess:
    # recreate saved graph (structure)
    saver = tf.train.import_meta_graph('./test_model.meta')
    # restore net params
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    x = G.get_operation_by_name("x").outputs[0]
    y = G.get_operation_by_name("y").outputs
    out = sess.run(fetches=[y], feed_dict={x: 1.0})
    print(out[0])
Run Code Online (Sandbox Code Playgroud)

优化尝试

但是,虽然我对优化没有太多期待,但当我尝试优化图形进行推理时,我收到以下错误消息.预期的输出节点似乎不在保存的图形中.

$ python -m tensorflow.python.tools.optimize_for_inference --input test_model.data-00000-of-00001 --output opt_model --input_names=x --output_names=y  
Traceback (most recent call last):  
  File "/usr/lib/python2.7/runpy.py", line 174, in _run_module_as_main  
    "__main__", fname, loader, pkg_name)  
  File "/usr/lib/python2.7/runpy.py", line 72, in _run_code  
    exec code in run_globals  
  File "/{path}/lib/python2.7/site-packages/tensorflow/python/tools/optimize_for_inference.py", line 141, in <module>  
    app.run(main=main, argv=[sys.argv[0]] + unparsed)  
  File "/{path}/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run  
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/{path}/lib/python2.7/site-packages/tensorflow/python/tools/optimize_for_inference.py", line 90, in main  
    FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)  
  File "/{path}/local/lib/python2.7/site-packages/tensorflow/python/tools/optimize_for_inference_lib.py", line 91, in optimize_for_inference  
    placeholder_type_enum)  
  File "/{path}/local/lib/python2.7/site-packages/tensorflow/python/tools/strip_unused_lib.py", line 71, in strip_unused  
    output_node_names)  
  File "/{path}/local/lib/python2.7/site-packages/tensorflow/python/framework/graph_util_impl.py", line 141, in extract_sub_graph  
    assert d in name_to_node_map, "%s is not in graph" % d  
AssertionError: y is not in graph  
Run Code Online (Sandbox Code Playgroud)

进一步的调查让我检查了保存图表的检查点,该图表只显示了1个张量(a,没有x和没有y).

(tf-1.0.1) $ python -m tensorflow.python.tools.inspect_checkpoint --file_name ./test_model --all_tensors
tensor_name:  a
5.0
Run Code Online (Sandbox Code Playgroud)

具体问题

  1. 为什么我没有看到x,并y在检查点?是因为它们是操作而不是张量?
  2. 由于我需要为optimize_for_inference模块提供输入和输出名称,如何构建图形以便我可以引用输入和输出节点?

vij*_*y m 46

以下是有关如何优化推理的详细指南:

optimize_for_inference模块将frozen binary GraphDef文件作为输入并输出optimized Graph Def可用于推理的文件.并且frozen binary GraphDef file需要使用模块freeze_graph,该模块将a GraphDef proto,a SaverDef proto和一组变量存储在检查点文件中.实现这一目标的步骤如下:

1.保存张量流图

 # make and save a simple graph
 G = tf.Graph()
 with G.as_default():
   x = tf.placeholder(dtype=tf.float32, shape=(), name="x")
   a = tf.Variable(5.0, name="a")
   y = tf.add(a, x, name="y")
   saver = tf.train.Saver()

with tf.Session(graph=G) as sess:
   sess.run(tf.global_variables_initializer())
   out = sess.run(fetches=[y], feed_dict={x: 1.0})

  # Save GraphDef
  tf.train.write_graph(sess.graph_def,'.','graph.pb')
  # Save checkpoint
  saver.save(sess=sess, save_path="test_model")
Run Code Online (Sandbox Code Playgroud)

2.冻结图表

python -m tensorflow.python.tools.freeze_graph --input_graph graph.pb --input_checkpoint test_model --output_graph graph_frozen.pb --output_node_names=y
Run Code Online (Sandbox Code Playgroud)

3.优化推理

python -m tensorflow.python.tools.optimize_for_inference --input graph_frozen.pb --output graph_optimized.pb --input_names=x --output_names=y
Run Code Online (Sandbox Code Playgroud)

4.使用优化图

with tf.gfile.GFile('graph_optimized.pb', 'rb') as f:
   graph_def_optimized = tf.GraphDef()
   graph_def_optimized.ParseFromString(f.read())

G = tf.Graph()

with tf.Session(graph=G) as sess:
    y, = tf.import_graph_def(graph_def_optimized, return_elements=['y:0'])
    print('Operations in Optimized Graph:')
    print([op.name for op in G.get_operations()])
    x = G.get_tensor_by_name('import/x:0')
    out = sess.run(y, feed_dict={x: 1.0})
    print(out)

#Output
#Operations in Optimized Graph:
#['import/x', 'import/a', 'import/y']
#6.0
Run Code Online (Sandbox Code Playgroud)

5.对于多个输出名称

如果有多个输出节点,则指定:output_node_names = 'boxes, scores, classes'并导入图形,

 boxes,scores,classes, = tf.import_graph_def(graph_def_optimized, return_elements=['boxes:0', 'scores:0', 'classes:0'])
Run Code Online (Sandbox Code Playgroud)

  • 很好的答案!您的评论非常宝贵,因为我没有意识到您必须单独保存图表和检查点.顺便说一句,我确实必须将`--input_checkpoint test_model`更改为`--input_checkpoint./ test_model`以使`freeze_graph`工作. (2认同)