Tensorflow 模型保存和加载

Sho*_*rma 5 python deep-learning conv-neural-network keras tensorflow

如何像我们在 keras 中所做的那样使用模型图保存张量流模型。我们可以保存整个模型(权重和图)并稍后导入,而不是在预测文件中再次定义整个图

在凯拉斯:

checkpoint = ModelCheckpoint('RightLane-{epoch:03d}.h5',monitor='val_loss', verbose=0,  save_best_only=False, mode='auto')
Run Code Online (Sandbox Code Playgroud)

将给出一个我们可以用于预测的 h5 文件

model = load_model("RightLane-030.h5")
Run Code Online (Sandbox Code Playgroud)

如何在本机张量流中做同样的事情

BiB*_*iBi 4

方法 1:将图和权重冻结在一个文件中(可能无法重新训练)

此选项显示如何将图形和权重保存在一个文件中。其预期用例是在训练后部署/共享模型。为此,我们将使用 protobuf (pb) 格式。

给定一个张量流会话(和图),您可以使用以下命令生成 protobuf

# freeze variables
output_graph_def = tf.graph_util.convert_variables_to_constants(
                               sess=sess,
                               input_graph_def =sess.graph.as_graph_def(),
                               output_node_names=['myMode/conv/output'])

# write protobuf to disk
with tf.gfile.GFile('graph.pb', "wb") as f:
    f.write(output_graph_def.SerializeToString())
Run Code Online (Sandbox Code Playgroud)

其中output_node_names需要图形结果节点的名称字符串列表(参见tensorflow文档)。

然后,您可以加载 protobuf 并获取其权重的图表,以便轻松执行前向传递。

with tf.gfile.GFile(path_to_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    return graph
Run Code Online (Sandbox Code Playgroud)

方法2:恢复元图和检查点(轻松重新训练)

如果您希望能够继续训练模型,您可能需要恢复完整的图,即权重以及损失函数、一些梯度信息(例如 Adam 优化器)等。

使用时需要tensorflow生成的meta和checkpoint文件

saver = tf.train.Saver(...variables...)
saver.save(sess, 'my-model')
Run Code Online (Sandbox Code Playgroud)

这将生成两个文件,my-model以及my-model.meta.

从这两个文件中,您可以使用以下命令加载图表:

  new_saver = tf.train.import_meta_graph('my-model.meta')
  new_saver.restore(sess, 'my-model')
Run Code Online (Sandbox Code Playgroud)

更详细的内容可以查看官方文档