如何在张量流中恢复部分图?

Sen*_*t07 4 python machine-learning deep-learning tensorflow

我想只恢复张量流中计算图的一部分。我的架构包含两个网络。第一个网络的输出是第二个网络的输入。第一个网络是预先训练的,我想从检查点恢复。我也不想更新第一个网络的参数。我可以遵循一个例子来实现这一目标吗?

谢谢

Jie*_*hou 5

我没有适合您任务的确切代码,但这里有一个简短的指南可以帮助您:

首先,您需要将网络解析为tf.GraphDef格式代码,应如下所示:

graph_def = tf.GraphDef()
with tf.gfile.FastGFile("path/to/graphdef") as f:
  s = f.read()
graph_def.ParseFromString(s)
Run Code Online (Sandbox Code Playgroud)

或从检查点/saved_mode 恢复,然后转换为GraphDef

tf.train.import_meta_graph('checkpoint.meta')
tf.get_default_graph().as_graph_def()
Run Code Online (Sandbox Code Playgroud)

现在你有了 graph_def

其次,提取graph_defwith的子图tf.graph_util.extract_sub_graph,您也可以指定目标节点,这些节点也是您第二个网络的输入。

最后,使用 导入第二步中的子图tf.import_graph_def

另外,由于您不想更新第一个网络的参数,因此可以使用以下命令冻结其参数tf.graph_util.convert_variables_to_constants