Sen*_*t07 4 python machine-learning deep-learning tensorflow
我想只恢复张量流中计算图的一部分。我的架构包含两个网络。第一个网络的输出是第二个网络的输入。第一个网络是预先训练的,我想从检查点恢复。我也不想更新第一个网络的参数。我可以遵循一个例子来实现这一目标吗?
谢谢
我没有适合您任务的确切代码,但这里有一个简短的指南可以帮助您:
首先,您需要将网络解析为tf.GraphDef格式代码,应如下所示:
graph_def = tf.GraphDef()
with tf.gfile.FastGFile("path/to/graphdef") as f:
s = f.read()
graph_def.ParseFromString(s)
Run Code Online (Sandbox Code Playgroud)
或从检查点/saved_mode 恢复,然后转换为GraphDef:
tf.train.import_meta_graph('checkpoint.meta')
tf.get_default_graph().as_graph_def()
Run Code Online (Sandbox Code Playgroud)
现在你有了 graph_def
其次,提取graph_defwith的子图tf.graph_util.extract_sub_graph,您也可以指定目标节点,这些节点也是您第二个网络的输入。
最后,使用 导入第二步中的子图tf.import_graph_def。
另外,由于您不想更新第一个网络的参数,因此可以使用以下命令冻结其参数tf.graph_util.convert_variables_to_constants
| 归档时间: |
|
| 查看次数: |
1825 次 |
| 最近记录: |