我是一个有张量流的初学者,所以请原谅这是一个愚蠢的问题,答案是显而易见的.
我创建了一个Tensorflow图,从X和y的占位符开始,我已经优化了一些代表我的模型的张量.图的一部分是可以计算预测矢量的东西,例如线性回归
y_model = tf.add(tf.mul(X,w),d)
y_vals = sess.run(y_model,feed_dict={....})
Run Code Online (Sandbox Code Playgroud)
训练结束后,我有w和d的可接受值,现在我想保存我的模型以供日后使用.然后,在另一个python会话中,我想恢复模型,以便我可以再次运行
## Starting brand new python session
import tensorflow as tf
## somehow restor the graph and the values here: how????
## so that I can run this:
y_vals = sess.run(y_model,feed_dict={....})
Run Code Online (Sandbox Code Playgroud)
对于某些不同的数据并取回y值.
我希望这种方式能够存储和恢复用于计算占位符的y值的图形 - 只要占位符获得正确的数据,这应该在没有用户(应用程序的用户)的情况下透明地工作.模型)需要知道图形是什么样的).
据我所知tf.train.Saver().save(..)只保存变量,但我也想保存图形.我认为tf.train.export_meta_graph在这里可能是相关的,但我不明白如何正确使用它,文档对我来说有点神秘,并且示例甚至不使用export_meta_graph.
从文档中,试试这个:
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
Run Code Online (Sandbox Code Playgroud)
您可以指定路径.
如果要恢复模型,请尝试:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
saver.restore(sess, "/tmp/model.ckpt")
Run Code Online (Sandbox Code Playgroud)
import tensorflow as tf
# Create some placeholder variables
x_pl = tf.placeholder(..., name="x")
y_pl = tf.placeholder(..., name="y")
# Add some operation to the Graph
add_op = tf.add(x, y)
with tf.Session() as sess:
# Add variable initializer
init = tf.global_variables_initializer()
# Add ops to save variables to checkpoints
# Unless var_list is specified Saver will save ALL named variables
# in Graph
# Optionally set maximum of 3 latest models to be saved
saver = tf.train.Saver(max_to_keep=3)
# Run variable initializer
sess.run(init)
for i in range(no_steps):
# Feed placeholders with some data and run operation
sess.run(add_op, feed_dict={x_pl: i+1, y_pl: i+5})
saver.save(sess, "path/to/checkpoint/model.ckpt", global_step=i)
Run Code Online (Sandbox Code Playgroud)
这将保存以下文件:
1)元图
.meta 文件:
MetaGraph的MetaGraphDef协议缓冲区表示,它保存完整的Tf Graph结构,即描述数据流的GraphDef以及与之关联的所有元数据,例如所有变量,操作,集合等.
导入图形结构将重新创建Graph及其所有变量,然后可以从检查点文件中恢复这些变量的相应值
如果您不想恢复Graph,但是您可以通过重新执行构建模型的Python代码来重建MetaGraphDef中的所有信息nb必须首先重新创建EXACT SAME变量,然后再从检查点恢复它们的值
由于并不总是需要Meta Graph文件,因此您可以关闭saver.save使用中的文件写入write_meta_graph=False
2)检查点文件
.data 文件:
tf.train.Saver()(默认为所有变量).index 文件:
描述所有张量及其元数据检查点文件的不可变表:
保存最新检查点文件的记录
import tensorflow as tf
latest_checkpoint = tf.train.latest_checkpoint("path/to/checkpoint")
# Load latest checkpoint Graph via import_meta_graph:
# - construct protocol buffer from file content
# - add all nodes to current graph and recreate collections
# - return Saver
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
# Start session
with tf.Session() as sess:
# Restore previously trained variables from disk
print("Restoring Model: {}".format("path/to/checkpoint"))
saver.restore(sess, latest_checkpoint)
# Retrieve protobuf graph definition
graph = tf.get_default_graph()
print("Restored Operations from MetaGraph:")
for op in graph.get_operations():
print(op.name)
# Access restored placeholder variables
x_pl = graph.get_tensor_by_name("x_pl:0")
y_pl = graph.get_tensor_by_name("y_pl:0")
# Access restored operation to re run
accuracy_op = graph.get_tensor_by_name("accuracy_op:0")
Run Code Online (Sandbox Code Playgroud)
这只是一个基础知识的简单示例,对于工作实现,请参见此处.
| 归档时间: |
|
| 查看次数: |
5793 次 |
| 最近记录: |