我使用tensorflow 1.5.1训练了一些模型,我有这些模型的检查点(包括.ckpt和.meta文件).现在我想使用这些文件在c ++中进行推理.
在python中,我将执行以下操作来保存和加载图形和检查点.保存:
images = tf.placeholder(...) // the input layer
//the graph def
output = tf.nn.softmax(net) // the output layer
tf.add_to_collection('images', images)
tf.add_to_collection('output', output)
Run Code Online (Sandbox Code Playgroud)
推理我恢复图形和检查点,然后从集合中恢复输入和输出层,如下所示:
meta_file = './models/last-100.meta'
ckpt_file = './models/last-100'
with tf.Session() as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, ckpt_file)
images = tf.get_collection('images')
output = tf.get_collection('output')
outputTensors = sess.run(output, feed_dict={images: np.array(an_image)})
Run Code Online (Sandbox Code Playgroud)
现在假设我像往常一样在python中进行保存,如何使用python中的简单代码在c ++中进行推理和恢复?
我找到了示例和教程,但对于tensorflow版本0.7 0.12,相同的代码不适用于1.5版本.我在tensorflow网站上找不到使用c ++ API恢复模型的教程.
为了这个线程.我会将我的评论改写成答案.
发布完整示例将需要CMake设置或将文件放入特定目录以运行bazel.因为我赞成第一种方式,它会破坏这篇文章的所有限制,以涵盖所有部分,我想重定向到C99,C++,GO的完整实现,没有Bazel,我测试了TF> v1.5.
在C++中加载图形并不比在Python中困难,因为您已经从源代码编译了TensorFlow.
首先创建一个MWE,创建一个非常好的转储网络图总是一个好主意,弄清楚事情是如何工作的:
import tensorflow as tf
x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
output = tf.identity(tf.layers.dense(x, 1), name='output')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, './exported/my_model')
Run Code Online (Sandbox Code Playgroud)
关于这一部分,关于SO可能有很多答案.所以我只是让它留在这里,没有进一步的解释.
在用其他语言做事之前,我们可以尝试在python中正确地做 - 在某种意义上:我们只需要用C++重写它.即使在python中恢复也很容易:
import tensorflow as tf
with tf.Session() as sess:
# load the computation graph
loader = tf.train.import_meta_graph('./exported/my_model.meta')
sess.run(tf.global_variables_initializer())
loader = loader.restore(sess, './exported/my_model')
x = tf.get_default_graph().get_tensor_by_name('input:0')
output = tf.get_default_graph().get_tensor_by_name('output:0')
Run Code Online (Sandbox Code Playgroud)
它没有用,因为大多数这些API端点在C++ API中都不存在(但是?).另一种版本是
import tensorflow as tf
with tf.Session() as sess:
metaGraph = tf.train.import_meta_graph('./exported/my_model.meta')
restore_op_name = metaGraph.as_saver_def().restore_op_name
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name
sess.run(restore_op, {filename_tensor_name: './exported/my_model'})
x = tf.get_default_graph().get_tensor_by_name('input:0')
output = tf.get_default_graph().get_tensor_by_name('output:0')
Run Code Online (Sandbox Code Playgroud)
不挂断.您可以随时使用print(dir(object))获取类似的属性restore_op_name....... 恢复模型是TensorFlow中的操作,就像其他所有操作一样.我们只是调用此操作并提供路径(字符串张量)作为输入.我们甚至可以编写自己的restore操作
def restore(sess, metaGraph, fn):
restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all'
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const'
sess.run(restore_op, {filename_tensor_name: fn})
Run Code Online (Sandbox Code Playgroud)
即使这看起来很奇怪,现在在C++中做同样的事情也很有帮助.
从通常的东西开始
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/public/session_options.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
#include <string>
#include <iostream>
typedef std::vector<std::pair<std::string, tensorflow::Tensor>> tensor_dict;
int main(int argc, char const *argv[]) {
const std::string graph_fn = "./exported/my_model.meta";
const std::string checkpoint_fn = "./exported/my_model";
// prepare session
tensorflow::Session *sess;
tensorflow::SessionOptions options;
TF_CHECK_OK(tensorflow::NewSession(options, &sess));
// here we will put our loading of the graph and weights
return 0;
}
Run Code Online (Sandbox Code Playgroud)
您应该能够通过将其放入TensorFlow仓库并使用bazel或只需按照此处的说明使用CMake进行编译.
我们需要创建这样的meta_graph创建者tf.train.import_meta_graph.这可以通过
tensorflow::MetaGraphDef graph_def;
TF_CHECK_OK(ReadBinaryProto(tensorflow::Env::Default(), graph_fn, &graph_def));
Run Code Online (Sandbox Code Playgroud)
在C++中从文件中读取的曲线图是不一样在Python导入的曲线图.我们需要在会话中创建此图表
TF_CHECK_OK(sess->Create(graph_def.graph_def()));
Run Code Online (Sandbox Code Playgroud)
通过查看restore上面的奇怪python 函数:
restore_op_name = metaGraph.as_saver_def().restore_op_name
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name
Run Code Online (Sandbox Code Playgroud)
我们可以用C++编写等效的代码
const std::string restore_op_name = graph_def.saver_def().restore_op_name()
const std::string filename_tensor_name = graph_def.saver_def().filename_tensor_name()
Run Code Online (Sandbox Code Playgroud)
有了这个,我们只需运行操作
sess->Run(feed_dict, // inputs
{}, // output_tensor_names (we do not need them)
{restore_op}, // target_node_names
nullptr) // outputs (there are no outputs this time)
Run Code Online (Sandbox Code Playgroud)
创建feed_dict可能是一个独立的帖子,这个答案已经足够长了.它只涵盖了最重要的东西.我想重定向到C99,C++,GO的完整实现,没有Bazel,我测试了TF> v1.5.这并不难 - 在普通C版本的情况下,它可能会变得非常长.