I've found two ways to save a model in TensorFlow: tf.train.Saver() and SavedModelBuilder. However, I can't find any documentation on how to use the model after loading it the second way.
Note: I want to use SavedModelBuilder because I train the model in Python and will serve it from another language (Go). In that case SavedModelBuilder seems to be the only option.
This works fine with tf.train.Saver() (the first way):
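import tensorflow as tf  # TensorFlow 1.x API

# Minimal model (definitions assumed; the original snippet omits them)
x = tf.placeholder(tf.float32, name="x")
W = tf.Variable([0.3], dtype=tf.float32)
b = tf.Variable([-0.3], dtype=tf.float32)
model = tf.add(W * x, b, name="finalnode")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # save
    saver = tf.train.Saver()
    saver.save(sess, "/tmp/model")
    # load
    saver.restore(sess, "/tmp/model")
    # IMPORTANT PART: ACTUALLY USING THE MODEL AFTER LOADING IT
    # I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.
    graph = tf.get_default_graph()
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7]}))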
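tf.saved_model.builder.SavedModelBuilder() is described in the README, but after loading the model with tf.saved_model.loader.load(sess, [], export_dir), I can't find any documentation on how to get back the nodes (see "finalnode" in the code above).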
Answer by Tho*_*mas (score: 17)
What's missing is the signature:
import tensorflow as tf  # TensorFlow 1.x API

# Saving (sess, x and model are from the question; export_dir must not exist yet)
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map={
    "model": tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={"x": x},
        outputs={"finalnode": model})
})
builder.save()

# Loading (into a fresh graph, with the same tag used when saving)
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7, 8]}))
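Both answers ultimately look tensors up by name. In a TensorFlow 1.x graph, a tensor name such as "finalnode:0" is the name of the operation that produces it (finalnode) followed by the index of that operation's output (0), which is why the question uses "finalnode:0" rather than "finalnode". This matters when serving from Go, whose bindings address the same tensor as an operation plus an output index (Graph.Operation("finalnode").Output(0)) instead of a "name:index" string. The helper below is a hypothetical, pure-Python sketch, not a TensorFlow API, that splits such a name into those two parts.

```python
def split_tensor_name(tensor_name):
    """Split a TF1-style tensor name "op_name:output_index" into its parts.

    Hypothetical helper for illustration only; it is not a TensorFlow API.
    """
    op_name, sep, index = tensor_name.rpartition(":")
    # Require exactly the "name:digits" shape, as in "finalnode:0".
    if not sep or not op_name or not index.isdigit():
        raise ValueError("expected 'op_name:index', got %r" % (tensor_name,))
    return op_name, int(index)


print(split_tensor_name("finalnode:0"))  # ('finalnode', 0)
print(split_tensor_name("x:0"))          # ('x', 0)
```

Splitting on the last colon keeps name scopes such as "scope/dense/BiasAdd:1" intact, since scopes use "/" rather than ":".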
Answer by 小智 (score: 5)
Here is a snippet that uses simple_save to save the model, then restores it and runs predictions.
# Save the model:
tf.saved_model.simple_save(sess, export_dir=saveModelPath,
                           inputs={"inputImageBatch": X_train, "inputClassBatch": Y_train,
                                   "isTrainingBool": isTraining},
                           outputs={"predictedClassBatch": predClass})
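Note that simple_save fills in certain defaults (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py). In particular, it tags the graph with tag_constants.SERVING and stores the signature under signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, which is why the loading code uses both.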
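To make those defaults concrete: in TensorFlow 1.x, tag_constants.SERVING is the string "serve", DEFAULT_SERVING_SIGNATURE_DEF_KEY is "serving_default", and simple_save builds a predict signature whose method name is "tensorflow/serving/predict". The sketch below models what gets recorded using plain dictionaries. The function name and dictionary layout are hypothetical and only mirror the fields of a SignatureDef; they are not TensorFlow's internal representation, and the tensor names in the usage lines are placeholders. These string values are also what a Go client would use, for example the tag "serve" passed to the Go binding's LoadSavedModel.

```python
# String values of the TF 1.x constants that simple_save relies on.
SERVING_TAG = "serve"                          # tag_constants.SERVING
DEFAULT_SIGNATURE_KEY = "serving_default"      # DEFAULT_SERVING_SIGNATURE_DEF_KEY
PREDICT_METHOD = "tensorflow/serving/predict"  # signature_constants.PREDICT_METHOD_NAME


def describe_simple_save(inputs, outputs):
    """Model what simple_save records, as plain dicts (hypothetical layout).

    `inputs` and `outputs` map the caller-chosen keys to tensor names.
    """
    return {
        "tags": [SERVING_TAG],
        "signature_def_map": {
            DEFAULT_SIGNATURE_KEY: {
                "method_name": PREDICT_METHOD,
                "inputs": dict(inputs),
                "outputs": dict(outputs),
            }
        },
    }


# Placeholder tensor names, for illustration only.
record = describe_simple_save(
    inputs={"inputImageBatch": "input_images:0", "isTrainingBool": "is_training:0"},
    outputs={"predictedClassBatch": "pred_class:0"},
)
print(record["tags"])                       # ['serve']
print(sorted(record["signature_def_map"]))  # ['serving_default']
```

Because simple_save always writes exactly one signature under the default key, a loader never has to guess the key; with SavedModelBuilder you choose the tag and key yourself (as "tag" and "model" in the previous answer).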
Now, to restore the model and use the input/output dictionaries:
import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants

with tf.Session(graph=tf.Graph()) as sess:
    # simple_save applies the SERVING tag by default, so load with that tag.
    model = tf.saved_model.loader.load(export_dir=saveModelPath, sess=sess, tags=[tag_constants.SERVING])
    # simple_save stores its signature under the default serving key.
    signature = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    graph = tf.get_default_graph()
    # Map the dictionary keys used at save time back to tensors in the loaded graph.
    inputImage = graph.get_tensor_by_name(signature.inputs['inputImageBatch'].name)
    inputLabel = graph.get_tensor_by_name(signature.inputs['inputClassBatch'].name)  # not needed for prediction
    isTraining = graph.get_tensor_by_name(signature.inputs['isTrainingBool'].name)
    outputPrediction = graph.get_tensor_by_name(signature.outputs['predictedClassBatch'].name)
    outPred = sess.run(outputPrediction, feed_dict={inputImage: sampleImages, isTraining: False})
    print("predicted classes:", outPred)
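Note: the signature_def (which simple_save stores under the default serving key) is what maps the keys of the input and output dictionaries to the actual tensor names, so it is required for this lookup.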
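The restore code follows one pattern: the signature maps each key chosen at save time ("inputImageBatch", "predictedClassBatch" and so on) to a tensor name, so the caller only needs to know the keys. Below is a hypothetical, pure-Python sketch of that resolution step. The signature is modeled as plain dicts of key to tensor name, and the tensor names are made up. It turns a feed keyed by signature keys into one keyed by tensor names; in TF 1.x, Session.run also accepts tensor-name strings as fetches and feed_dict keys, so the resulting mapping could be passed to it directly.

```python
def resolve_feeds(signature, feeds):
    """Translate {signature input key: value} into {tensor name: value}.

    `signature` is a hypothetical plain-dict stand-in for a SignatureDef:
    {"inputs": {key: tensor_name}, "outputs": {key: tensor_name}}.
    Unknown keys raise KeyError so typos fail loudly.
    """
    inputs = signature["inputs"]
    unknown = set(feeds) - set(inputs)
    if unknown:
        raise KeyError("not in signature inputs: %s" % sorted(unknown))
    return {inputs[key]: value for key, value in feeds.items()}


def resolve_fetch(signature, output_key):
    """Return the tensor name behind a signature output key."""
    return signature["outputs"][output_key]


# Made-up tensor names; a real SavedModel records whatever names the graph used.
signature = {
    "inputs": {
        "inputImageBatch": "input_images:0",
        "inputClassBatch": "input_labels:0",
        "isTrainingBool": "is_training:0",
    },
    "outputs": {"predictedClassBatch": "pred_class:0"},
}

# Like the restore code, feed only the image batch and the training flag.
feed = resolve_feeds(signature, {"inputImageBatch": [[0.0]], "isTrainingBool": False})
print(feed)  # {'input_images:0': [[0.0]], 'is_training:0': False}
print(resolve_fetch(signature, "predictedClassBatch"))  # pred_class:0
```

Keeping the keys stable while the underlying tensor names change is the point of a signature: a Go client (or any other caller) can read the signature once and never hard-code graph-internal names.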
Views: 14241