Tensorflow:如何将.meta,.data和.index模型文件转换为一个graph.pb文件

Raf*_*fal 26 meta model graph checkpoint tensorflow

在tensorflow中,从头开始训练产生以下6个文件:

  1. events.out.tfevents.1503494436.06L7-BRM738
  2. model.ckpt-22480.meta
  3. 检查站
  4. model.ckpt-22480.data 00000-的-00001
  5. model.ckpt-22480.index
  6. graph.pbtxt

我想将它们(或仅需要的)转换为一个文件graph.pb,以便能够将其传输到我的Android应用程序.

我尝试了脚本,freeze_graph.py但它需要输入我还没有的input.pb文件.(我之前只提到过这6个文件).如何获取这个freezed_graph.pb文件?我看到几个线程,但没有一个为我工作.

vel*_*niy 33

您可以使用此简单脚本来执行此操作.但是您必须指定输出节点的名称.

import tensorflow as tf

meta_path = 'model.ckpt-22480.meta' # Your .meta file
output_node_names = ['output:0']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('path/of/your/.meta/file'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())
Run Code Online (Sandbox Code Playgroud)

如果您不知道输出节点或节点的名称,有两种方法

  1. 您可以浏览图表并使用Netron或控制台summarize_graph实用程序查找名称.

  2. 您可以将所有节点用作输出节点,如下所示.

output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]
Run Code Online (Sandbox Code Playgroud) (注意,您必须在`convert_variables_to_constants`调用之前放置此行.)

但我认为这是不寻常的情况,因为如果您不知道输出节点,则无法实际使用该图.

  • 有没有一种简单的方法来获取输出节点名称? (7认同)
  • 我收到此错误,可能是因为我不确定我的 output_node_names 是否正确。`文件“/path/to/saver.py”,第 1796 行,在恢复中引发 ValueError(“当它为 None 时无法加载 save_path。”)` (2认同)
  • 如果有人来到这里并遇到与我相同的问题,也就是在尝试冻结图形时,它将因“尝试使用未初始化的值”而失败,只需添加 `init=tf.global_variables_initializer() sess.run(init)`加载权重后。 (2认同)

小智 7

因为它可能对其他人有帮助,我也在回答github后回答这里;-).我想你可以尝试这样的东西(使用tensorflow/python/tools中的freeze_graph脚本):

python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "
Run Code Online (Sandbox Code Playgroud)

这里的重要标志是--input_binary = false,因为文件graph.pbtxt是文本格式.我认为它对应于所需的graph.pb,它是二进制格式的等价物.

关于output_node_names,这对我来说真的很困惑,因为我在这部分仍然有一些问题,但你可以在tensorflow中使用summarize_graph脚本,它可以将pb或pbtxt作为输入.

问候,

斯蒂芬