Ser*_*gey 5 python graph tensorflow
我正在尝试使用 tensorflow 进行研究,但我不明白如何打开和使用早期保存在文件中的类型为 tf.Graph 的图形。像这样的东西:
import tensorflow as tf
my_graph = tf.Graph()
with g.as_default():
x = tf.Variable(0)
b = tf.constant(-5)
k = tf.constant(2)
y = k*x + b
tf.train.write_graph(my_graph, '.', 'graph.pbtxt')
f = open('graph.pbtxt', "r")
# Do something with "f" to get my saved graph and use it below in
# tf.Session(graph=...) instead of dots
with tf.Session(graph=...) as sess:
tf.initialize_all_variables().run()
y1 = sess.run(y, feed_dict={x: 5})
y2 = sess.run(y, feed_dict={x: 10})
print(y1, y2)
Run Code Online (Sandbox Code Playgroud)
您必须加载文件内容,将其解析为 GraphDef,然后导入。它将被导入到当前图形中。您可能想用graph.as_default():上下文管理器包装它。
import tensorflow as tf
from tensorflow.core.framework import graph_pb2 as gpb
from google.protobuf import text_format as pbtf
gdef = gpb.GraphDef()
with open('my-graph.pbtxt', 'r') as fh:
graph_str = fh.read()
pbtf.Parse(graph_str, gdef)
tf.import_graph_def(gdef)
Run Code Online (Sandbox Code Playgroud)
我这样解决了这个问题:首先,我在图形“输出”中命名所需的计算,然后将该模型保存在下面的代码中......
import tensorflow as tf
x = tf.placeholder(dtype=tf.float64, shape=[], name="input")
a = tf.Variable(111, name="var1", dtype=tf.float64)
b = tf.Variable(-666, name="var2", dtype=tf.float64)
y = tf.add(x, a, name="output")
saver = tf.train.Saver()
with tf.Session() as sess:
tf.initialize_all_variables().run()
print(sess.run(y, feed_dict={x: 555}))
save_path = saver.save(sess, "model.ckpt", meta_graph_suffix='meta', write_meta_graph=True)
print("Model saved in file: %s" % save_path)
Run Code Online (Sandbox Code Playgroud)
其次,我需要在图中运行某些操作,我知道该操作的名称为“输出”。因此,我只需在另一个代码中恢复模型,并通过采用名称为“输入”和“输出”的必要图形部分来运行恢复的计算:
import tensorflow as tf
# Restore graph to another graph (and make it default graph) and variables
graph = tf.Graph()
with graph.as_default():
saver = tf.train.import_meta_graph("model.ckpt.meta")
y = graph.get_tensor_by_name("output:0")
x = graph.get_tensor_by_name("input:0")
with tf.Session() as sess:
saver.restore(sess, "model.ckpt")
print(sess.run(y, feed_dict={x: 888}))
# Variable out:
for var in tf.all_variables():
print("%s %.2f" % (var.name, var.eval()))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
11669 次 |
| 最近记录: |