如何用数据集迭代器替换已保存图形的输入,例如占位符?

Wes*_*sam 5 python tensorflow tensorflow-datasets

我有一个保存的Tensorflow图,它通过placeholder一个feed_dict参数来消耗输入.

sess.run(my_tensor, feed_dict={input_image: image})
Run Code Online (Sandbox Code Playgroud)

因为馈送数据与Dataset Iterator更有效的,我要加载保存的图形,更换input_image placeholderIterator和运行.我怎样才能做到这一点?有没有更好的方法呢?代码示例的答案将受到高度赞赏.

P-G*_*-Gn 7

您可以通过序列化图形并使用重新导入来实现这一点tf.import_graph_def,该图形具有input_map用于在所需位置插入输入的参数.

要做到这一点,你需要至少要知道你更换并要执行(相应的输出输入的名字xy我的例子).

import tensorflow as tf

# restore graph (built from scratch here for the example)
x = tf.placeholder(tf.int64, shape=(), name='x')
y = tf.square(x, name='y')

# just for display -- you don't need to create a Session for serialization
with tf.Session() as sess:
  print("with placeholder:")
  for i in range(10):
    print(sess.run(y, {x: i}))

# serialize the graph
graph_def = tf.get_default_graph().as_graph_def()

tf.reset_default_graph()

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0'])

# enjoy Dataset inputs!
with tf.Session() as sess:
  print('with Dataset:')
  try:
    while True:
      print(sess.run(y))
  except tf.errors.OutOfRangeError:
    pass        
Run Code Online (Sandbox Code Playgroud)

请注意,占位符节点仍然存在,因为我没有在这里打扰解析graph_def它 - 你可以删除它作为改进,虽然我认为它也可以留在这里.

根据您恢复图形的方式,输入替换可能已经内置在加载程序中,这使事情变得更简单(无需返回到a GraphDef).例如,如果从.meta文件加载图形,则可以使用tf.train.import_meta_graph接受相同input_map参数的图形.

import tensorflow as tf

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# load your net and plug in new pipeline
# you need to know the name of the tensor where to plug-in your input
restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
y = tf.get_default_graph().get_tensor_by_name('y:0')

# enjoy Dataset inputs!
with tf.Session() as sess:
  # not needed here, but in practice you would also need to restore weights
  # restorer.restore(sess, weights_filepath)
  print('with Dataset:')
  try:
    while True:
      print(sess.run(y))
  except tf.errors.OutOfRangeError:
    pass        
Run Code Online (Sandbox Code Playgroud)