PhA*_*ABC 9 python session class graph tensorflow
我相信我很难理解图形在张量流中如何工作以及如何访问它们.我的直觉是'with graph:'下的线条将图形形成为单个实体.因此,我决定创建一个在实例化时构建图形的类,并且拥有一个运行图形的函数,如下所示;
class Graph(object):
#To build the graph when instantiated
def __init__(self, parameters ):
self.graph = tf.Graph()
with self.graph.as_default():
...
prediction = ...
cost = ...
optimizer = ...
...
# To launch the graph
def launchG(self, inputs):
with tf.Session(graph=self.graph) as sess:
...
sess.run(optimizer, feed_dict)
loss = sess.run(cost, feed_dict)
...
return variables
Run Code Online (Sandbox Code Playgroud)
接下来的步骤是创建一个主文件,它将汇集参数以传递给类,构建图形然后运行它;
#Main file
...
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... }
#Building graph
G = Graph(parameters_dict)
P = G.launchG(Input)
...
Run Code Online (Sandbox Code Playgroud)
这对我来说非常优雅,但它显然不起作用(显然).实际上,似乎launchG函数无法访问图中定义的节点,这给出了我的错误;
---> 26 sess.run(optimizer, feed_dict)
NameError: name 'optimizer' is not defined
Run Code Online (Sandbox Code Playgroud)
也许这是我的python(和tensorflow)理解太有限了,但我的奇怪印象是,在创建图形(G)的情况下,使用此图形作为参数运行会话应该可以访问其中的节点,而无需要求我提供明确的访问权限.
任何启示?
Oli*_*rot 14
节点prediction,cost和optimizer是在方法中创建的局部变量__init__,它们无法在方法中访问launchG.
最简单的解决方法是将它们声明为类的属性Graph:
class Graph(object):
#To build the graph when instantiated
def __init__(self, parameters ):
self.graph = tf.Graph()
with self.graph.as_default():
...
self.prediction = ...
self.cost = ...
self.optimizer = ...
...
# To launch the graph
def launchG(self, inputs):
with tf.Session(graph=self.graph) as sess:
...
sess.run(self.optimizer, feed_dict)
loss = sess.run(self.cost, feed_dict)
...
return variables
Run Code Online (Sandbox Code Playgroud)
您还可以使用他们的确切名称与检索图的节点graph.get_tensor_by_name和graph.get_operation_by_name.