如何从 Tensorflow 中的冻结模型(pb 文件)中找到 output_node_names?

Som*_*hah 3 python tensorflow

我正在尝试将 freeze_model.pb 转换为 TensorFlow JS 兼容(.pb)文件,该文件基于 Tensorflow 的 SSD Mobilenet V2 COCO 预训练模型。我陷入了如何获取使用tensorflowjs_converter 时所需的output_node_names 参数的困境。如何知道输出节点名称?

我尝试使用下面的 Python 脚本获取操作名称,但无法理解哪一个是输出节点。

def load_graph(model_file):
  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())
  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph

graph = load_graph('frozen_model.pb')
ops = graph.get_operations()
Run Code Online (Sandbox Code Playgroud)

小智 5

首先,您可以检查您的所有节点,graph_def如下所示:

for node in graph_def.node
    print(node.name)
Run Code Online (Sandbox Code Playgroud)

或者,如果您想直观地查看图形并确定将哪个节点用作输出,则可以使用 TensorBoard。有一个名为import_pb_to_tensorboard的工具。它本质上是使用几行将图形写入 a log_dir,您可以将张量板指向它。您只需将这些行复制到您自己的脚本中即可实现相同的效果,而无需从 TensorFlow 存储库进行构建。

第三,还有另一个工具叫summary_graph工具

bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/path/to/your/graph.pb
Run Code Online (Sandbox Code Playgroud)