Joh*_*hsm 15 python tensorflow
我尝试简单地保存和恢复图形,但最简单的示例不能按预期工作(这是使用版本0.9.0或0.10.0在Linux 64上使用python 2.7或3.5.2在没有CUDA的情况下完成的)
首先我保存图形如下:
import tensorflow as tf
v1 = tf.placeholder('float32')
v2 = tf.placeholder('float32')
v3 = tf.mul(v1,v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3,c1)
sess = tf.Session()
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3})
g1 = tf.train.export_meta_graph("file")
## alternately I also tried:
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"])
Run Code Online (Sandbox Code Playgroud)
这将创建一个非空的文件"文件",并将g1设置为看起来像正确的图形定义的东西.
然后我尝试恢复此图:
import tensorflow as tf
g=tf.train.import_meta_graph("file")
Run Code Online (Sandbox Code Playgroud)
这没有错误,但根本不返回任何内容.
任何人都可以提供必要的代码,只需保存"v4"的图形并完全恢复它,以便在新的会话中运行它会产生相同的结果吗?
mrr*_*rry 30
要重复使用MetaGraphDef
,您需要在原始图表中记录有趣张量的名称.例如,在第一个程序中,name
在和的定义中设置显式参数v1
,v2
并且v4
:
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
# ...
v4 = tf.add(v3, c1, name="v4")
Run Code Online (Sandbox Code Playgroud)
然后,您可以在通话中使用原始图表中张量的字符串名称sess.run()
.例如,以下代码段应该有效:
import tensorflow as tf
_ = tf.train.import_meta_graph("./file")
sess = tf.Session()
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
Run Code Online (Sandbox Code Playgroud)
或者,您可以使用tf.get_default_graph().get_tensor_by_name()
获取tf.Tensor
感兴趣的张量的对象,然后可以将其传递给sess.run()
:
import tensorflow as tf
_ = tf.train.import_meta_graph("./file")
g = tf.get_default_graph()
v1 = g.get_tensor_by_name("v1:0")
v2 = g.get_tensor_by_name("v2:0")
v4 = g.get_tensor_by_name("v4:0")
sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
Run Code Online (Sandbox Code Playgroud)
更新:根据评论中的讨论,这里是保存和加载的完整示例,包括保存变量内容.这说明了通过vx
在单独的操作中将变量的值加倍来保存变量.
保存:
import tensorflow as tf
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
print(result)
saver.save(sess, "./model_ex1")
Run Code Online (Sandbox Code Playgroud)
恢复:
import tensorflow as tf
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)
Run Code Online (Sandbox Code Playgroud)
最重要的是,为了使用已保存的模型,您必须记住至少一些节点的名称(例如,训练操作,输入占位符,评估张量等).该MetaGraphDef
商店包含在模型中,并有助于从检查点恢复这些,但你需要重建在训练中使用的张量/操作/评估自己模型的变量列表.