Tensorflow:将模型保存到model.pb中,以便以后可视化

use*_*212 2 python deep-learning tensorflow

我找到以下代码片段来可视化已保存到*.pb文件中的模型:

model_filename ='saved_model.pb'
with tf.Session() as sess:
    with gfile.FastGFile(path_to_model_pb, 'rb') as f:
        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
        LOGDIR='.'
        train_writer = tf.summary.FileWriter(LOGDIR)
        train_writer.add_graph(sess.graph)
Run Code Online (Sandbox Code Playgroud)

现在,我正在努力创建saved_model.pb。如果我的session.run看起来像这样:

  _, cr_loss = sess.run([train_op,cross_entropy_loss],
                         feed_dict={input_image: images,
                                    correct_label: gt_images,
                                    keep_prob:  KEEP_PROB,
                                    learning_rate: LEARNING_RATE}
                        )
Run Code Online (Sandbox Code Playgroud)

我如何保存图形所含train_opsaved_model.pb

jde*_*esa 7

最简单的方法是使用tf.train.write_graph。通常,您只需要执行以下操作:

tf.train.write_graph(my_graph, path_to_model_pb,
                     'saved_model.pb', as_text=False)
Run Code Online (Sandbox Code Playgroud)

my_graphtf.get_default_graph()如果您使用的是默认图形或任何其他tf.Graph(或tf.GraphDef)对象,则可以为。

请注意,这将保存图形定义,可视化它是可以的,但是如果您有变量,则除非先冻结图形,否则它们的值将不会保存在其中(因为这些值仅在会话对象中,而不是图形本身)。