如何使用TensorFlow中的Estimator将模型存储在`.pb`文件中?

ink*_*kzk 1 python tensorflow

我用TensorFlow的估算器训练了我的模型.它似乎export_savedmodel应该用于制作.pb文件,但我真的不知道如何构建serving_input_receiver_fn.有人有什么想法吗?欢迎使用示例代码.

额外问题:

  1. .pb当我想重新加载模型时,我需要的唯一文件是什么?Variable不必要?

  2. 与adam优化器.pb相比,模型文件大小会减少多少.ckpt

Ghi*_*ADJ 7

如果你正在使用,你可以使用freeze_graph.py生成一个.pbfrom .ckpt+ ,然后你会在中找到这两个文件.pbtxttf.estimator.Estimatormodel_dir

python freeze_graph.py \
    --input_graph=graph.pbtxt \
    --input_checkpoint=model.ckpt-308 \
    --output_graph=output_graph.pb
    --output_node_names=<output_node>
Run Code Online (Sandbox Code Playgroud)
  1. 当我想重新加载模型时,.pb是我需要的唯一文件吗?变量多余?

是的,您还必须知道您是模型的输入节点和输出节点名称.然后使用import_graph_def加载.pb文件并使用获取输入和输出操作get_operation_by_name

  1. 与使用adam优化器的.ckpt相比,.pb会减少多少模型文件大小?

.pb文件不是压缩的.ckpt文件,因此没有"压缩率".

但是,有一种方法可以优化.pb文件以进行推理,这种优化可能会减少文件大小,因为它会删除仅培训操作的图形部分(请参阅此处的完整说明).

[评论]如何获取输入和输出节点名称?

您可以使用op name参数设置输入和输出节点名称.

要列出.pbtxt文件中的节点名称,请使用以下脚本.

import tensorflow as tf
from google.protobuf import text_format

with open('graph.pbtxt') as f:
    graph_def = text_format.Parse(f.read(), tf.GraphDef())

print [n.name for n in graph_def.node]
Run Code Online (Sandbox Code Playgroud)

[评论]我发现有一个tf.estimator.Estimator.export_savedmodel(),是直接在.pb中存储模型的函数吗?而我正在努力参与其中的参数serve_input_receiver_fn.有任何想法吗?

export_savedmodel()生成一个TensorFlow模型SavedModel通用序列化格式.它应该包含适合TensorFlow服务API所需的一切

serving_input_receiver_fn()是生成a所必需的东西的一部分SavedModel,它通过向图形添加占位符来确定模型的输入签名.

来自doc

此功能具有以下用途:

  • 要向图表添加占位符,服务系统将使用推理请求进行提供.
  • 添加将输入格式的数据转换为模型预期的功能所需的任何其他操作.

如果您以序列化形式tf.Examples(这是一种典型模式)收到推理请求,那么您可以使用文档中提供的示例.

feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}

def serving_input_receiver_fn():
  """An input receiver that expects a serialized tf.Example."""
  serialized_tf_example = tf.placeholder(dtype=tf.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
  receiver_tensors = {'examples': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
Run Code Online (Sandbox Code Playgroud)

[评论]是否有想法在'.pb'中列出节点名称?

这取决于它是如何生成的.

如果它是一个SavedModel用途:

import tensorflow as tf

with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        './saved_models/1519232535')
    print [n.name for n in meta_graph_def.graph_def.node]
Run Code Online (Sandbox Code Playgroud)

如果它是一个MetaGraph然后使用:

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile('model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        print [n.name for n in graph_def.node]
Run Code Online (Sandbox Code Playgroud)