连接两个不同图形张量流的输入和输出张量

Pra*_*hur 8 protocol-buffers tensorflow

我有2个ProtoBuf文件,我目前正在加载和转发,分别通过调用 -

out1=session.run(graph1out, feed_dict={graph1inp:inp1})
Run Code Online (Sandbox Code Playgroud)

其次是

final=session.run(graph2out, feed_dict={graph2inp:out1})
Run Code Online (Sandbox Code Playgroud)

其中graph1inpgraph1out是输入节点和输出节点图1为和类似的术语图2

现在,我想连接graph1outgraph2inp这样,我只需要运行graph2out而喂养graph1inpINP1.换句话说,连接2个相关图形的输入和输出张量,使得一次运行足以在两个训练的ProtoBuf文件上运行推理.

mrr*_*rry 12

假设你的Protobuf文件包含序列化的tf.GraphDefprotos,你可以使用input_map参数tf.import_graph_def()来连接两个图:

# Import graph1.
graph1_def = ...  # tf.GraphDef object
out1_name = "..."  # name of the graph1out tensor in graph1_def.
graph1out, = tf.import_graph_def(graph1_def, return_elements=[out_name])

# Import graph2 and connect it to graph1.
graph2_def = ...  # tf.GraphDef object
inp2_name = "..."  # name of the graph2inp tensor in graph2_def.
out2_name = "..."  # name of the graph2out tensor in graph2_def.
graph2out, = tf.import_graph_def(graph2_def, input_map={inp2_name: graph1out},
                                 return_elements=[out2_name])
Run Code Online (Sandbox Code Playgroud)


小智 8

接受的答案确实会连接两个图,但是不会还原集合,全局变量和可训练变量。经过详尽的搜索,我得出了一个更好的解决方案:

import tensorflow as tf
from tensorflow.python.framework import meta_graph

with tf.Graph().as_default() as graph1:
    input = tf.placeholder(tf.float32, (None, 20), name='input')
    ...
    output = tf.identity(input, name='output')

with tf.Graph().as_default() as graph2:
    input = tf.placeholder(tf.float32, (None, 20), name='input')
    ...
    output = tf.identity(input, name='output')

graph = tf.get_default_graph()
x = tf.placeholder(tf.float32, (None, 20), name='input')
Run Code Online (Sandbox Code Playgroud)

我们使用tf.train.export_meta_graph该导出还CollectionDef并将meta_graph.import_scoped_meta_graph其导入。这是发生连接的地方,特别是在input_map参数中。

meta_graph1 = tf.train.export_meta_graph(graph=graph1)
meta_graph.import_scoped_meta_graph(meta_graph1, input_map={'input': x}), import_scope='graph1',
out1 = graph.get_tensor_by_name('graph1/output:0')

meta_graph2 = tf.train.export_meta_graph(graph=graph2)
meta_graph.import_scoped_meta_graph(meta_graph2, input_map={'input': out1}, import_scope='graph2')
Run Code Online (Sandbox Code Playgroud)

现在图已连接,并且全局变量正在重新映射。

print(tf.global_variables())
Run Code Online (Sandbox Code Playgroud)

您也可以直接从文件导入元图。