使用convert_variables_to_constants保存tf.trainable_variables()

use*_*763 14 keras tensorflow

我有一个Keras模型,我想转换为Tensorflow protobuf(例如saved_model.pb).

该模型来自vgg-19网络上的传输学习,其中头部被切断并使用完全连接的+ softmax层进行训练,而其余的vgg-19网络被冻结

我可以keras.backend.get_session()在Keras中加载模型,然后使用在tensorflow中运行模型,生成正确的预测:

frame = preprocess(cv2.imread("path/to/img.jpg")
keras_model = keras.models.load_model("path/to/keras/model.h5")

keras_prediction = keras_model.predict(frame)

print(keras_prediction)

with keras.backend.get_session() as sess:

    tvars = tf.trainable_variables()

    output = sess.graph.get_tensor_by_name('Softmax:0')
    input_tensor = sess.graph.get_tensor_by_name('input_1:0')

    tf_prediction = sess.run(output, {input_tensor: frame})
    print(tf_prediction) # this matches keras_prediction exactly
Run Code Online (Sandbox Code Playgroud)

如果我不包含该行tvars = tf.trainable_variables(),则该tf_prediction变量完全错误,并且根本不匹配输出keras_prediction.事实上,输出中的所有值(具有4个概率值的单个数组)完全相同(~0.25,全部加1).这让我怀疑,如果tf.trainable_variables()没有先调用头部的权重,只是初始化为0 ,这在检查模型变量后得到确认.在任何情况下,调用tf.trainable_variables()都会导致张量流预测正确.

问题是,当我尝试保存此模型时,来自的变量tf.trainable_variables()实际上并没有保存到.pb文件中:

with keras.backend.get_session() as sess:
    tvars = tf.trainable_variables()

    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ['Softmax'])
    graph_io.write_graph(constant_graph, './', 'saved_model.pb', as_text=False)
Run Code Online (Sandbox Code Playgroud)

我要问的是,如何将Keras模型保存为tf.training_variables()完整的Tensorflow protobuf ?

非常感谢!

Eli*_*xby 5

因此,冻结图表中变量(转换为常量)的方法应该可行,但不是必需的,并且比其他方法棘手。(有关详情,请参见下文)。如果您由于某种原因(例如导出到移动设备)而希望图形冻结,则我需要更多详细信息来帮助调试,因为我不确定Keras在您的图形背后进行了哪些隐式操作。但是,如果您只想稍后保存并加载图形,我可以解释如何执行此操作(尽管不能保证Keras所做的任何事情都不会搞砸……,很乐意帮助调试)。

因此,这里实际上有两种格式在起作用。一个是GraphDef,用于Checkpointing,因为它不包含有关输入和输出的元数据。另一个是,MetaGraphDef其中包含元数据和图形def,该元数据可用于预测和运行ModelServer(来自tensorflow / serving)。

无论哪种情况,您都需要做的不仅仅是调用,graph_io.write_graph因为变量通常存储在graphdef之外。

这两个用例都有包装器库。tf.train.Saver主要用于保存和还原检查点。

但是,由于您需要预测,因此建议您使用tf.saved_model.builder.SavedModelBuilder来构建SavedModel二进制文件。我为此提供了一些样板:

from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY as DEFAULT_SIG_DEF
builder = tf.saved_model.builder.SavedModelBuilder('./mymodel')
with keras.backend.get_session() as sess:
  output = sess.graph.get_tensor_by_name('Softmax:0')
  input_tensor = sess.graph.get_tensor_by_name('input_1:0')
  sig_def = tf.saved_model.signature_def_utils.predict_signature_def(
    {'input': input_tensor},
    {'output': output}
  )
  builder.add_meta_graph_and_variables(
      sess, tf.saved_model.tag_constants.SERVING,
      signature_def_map={
        DEFAULT_SIG_DEF: sig_def
      }
  )
builder.save()
Run Code Online (Sandbox Code Playgroud)

运行此代码后,您将拥有一个mymodel/saved_model.pb文件以及一个目录,mymodel/variables/其中包含与变量值相对应的protobuf。

然后再次加载模型,只需使用tf.saved_model.loader

# Does Keras give you the ability to start with a fresh graph?
# If not you'll need to do this in a separate program to avoid
# conflicts with the old default graph
with tf.Session(graph=tf.Graph()):
  meta_graph_def = tf.saved_model.loader.load(
      sess, 
      tf.saved_model.tag_constants.SERVING,
      './mymodel'
  )
  # From this point variables and graph structure are restored

  sig_def = meta_graph_def.signature_def[DEFAULT_SIG_DEF]
  print(sess.run(sig_def.outputs['output'], feed_dict={sig_def.inputs['input']: frame}))
Run Code Online (Sandbox Code Playgroud)

显然,通过tensorflow / serving或Cloud ML Engine,此代码可提供更有效的预测,但这应该可行。Keras可能会在后台进行某些操作,这也会干扰此过程,如果是的话,我们希望了解一下(我想确保Keras用户也能够冻结图表,因此,如果您想将完整的代码或其他内容发送给我,也许我可以找到一个了解Keras的人来帮助我进行调试。)

编辑:您可以在这里找到一个端到端的示例:https : //github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/keras/trainer/model.py#L85