使用Tensorflow检查点在C++中恢复模型

A. *_*iro 2 c++ python deep-learning tensorflow

我已经训练了一个我用Tensorflow使用Python实现的网络.最后,我用tf.train.Saver()保存了模型.现在我想用C++使用这个经过预先训练的网络进行预测.

我怎样才能做到这一点 ?有没有办法转换检查点,所以我可以使用tiny-dnn或Tensorflow C++?

欢迎任何想法:)谢谢!

ash*_*ash 8

您可能应该以SavedModel格式导出模型,该格式封装了计算图和保存的变量(tf.train.Saver仅保存变量,因此您无论如何都必须保存图形).

然后,您可以使用C++加载已保存的模型LoadSavedModel.

确切的调用取决于模型的输入和输出.但Python代码看起来像这样:

# You'd adjust the arguments here according to your model
signature = tf.saved_model.signature_def_utils.predict_signature_def(                                                                        
  inputs={'image': input_tensor}, outputs={'scores': output_tensor})                                                                         


builder = tf.saved_model.builder.SavedModelBuilder('/tmp/my_saved_model')                                                                    

builder.add_meta_graph_and_variables(                                                                                                        
   sess=sess,                                                                                                                    
   tags=[tf.saved_model.tag_constants.SERVING],                                                                                             
   signature_def_map={                                                                                                       
 tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:                                                                
        signature                                                                                                                        
})                                                                                                                                       

builder.save()
Run Code Online (Sandbox Code Playgroud)

然后在C++中你会做这样的事情:

tensorflow::SavedModelBundle model;
auto status = tensorflow::LoadSavedModel(session_options, run_options, "/tmp/my_saved_model", {tensorflow::kSavedModelTagServe}, &model);
if (!status.ok()) {
   std::cerr << "Failed: " << status;
   return;
}
// At this point you can use model.session
Run Code Online (Sandbox Code Playgroud)

(请注意,使用SavedModel格式还允许您使用TensorFlow服务提供模型,如果这对您的应用程序有意义)

希望有所帮助.