在函数内部构建张量流图

Ale*_*lex 6 python structure tensorflow

我正在学习Tensorflow,并且正在尝试正确构建我的代码.我(或多或少)知道如何构建裸图或类方法图,但我想弄清楚如何最好地构造代码.我试过这个简单的例子:

def build_graph():                
     g = tf.Graph()     
     with g.as_default():                       
         a = tf.placeholder(tf.int8)
         b = tf.add(a, tf.constant(1, dtype=tf.int8))
     return g   

graph = build_graph()
with tf.Session(graph=graph) as sess:
     feed = {a: 3}      
     print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)

应该打印出来4.然而,当我这样做时,我得到错误:

Cannot interpret feed_dict key as Tensor: Tensor 
Tensor("Placeholder:0", dtype=int8) is not an element of this graph.
Run Code Online (Sandbox Code Playgroud)

我很确定这是因为函数内的占位符build_graph是私有的,但不应该with tf.Session(graph=graph)照顾它吗?在这种情况下使用feed dict有更好的方法吗?

Min*_*ark 16

有几种选择.

选项1:只传递张量的名称而不是张量本身.

with tf.Session(graph=graph) as sess:
    feed = {"Placeholder:0": 3}      
    print(sess.run("Add:0", feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)

在这种情况下,最好给节点赋予有意义的名称,而不是使用上面的默认名称:

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8, name="a")
         b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
     return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
     feed = {"a:0": 3}
     print(sess.run("b:0", feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)

回想一下,名为的操作的输出"foo"是名为"foo:0",, "foo:1"等等的张量.大多数操作只有一个输出.

选项2:使您的build_graph()函数返回所有重要节点.

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8)
         b = tf.add(a, tf.constant(1, dtype=tf.int8))
     return g, a, b

graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)

选项3:向集合添加重要节点

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8)
         b = tf.add(a, tf.constant(1, dtype=tf.int8))
     for node in (a, b):
         g.add_to_collection("important_stuff", node)
     return g

graph = build_graph()
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)

选项4:正如@pohe所建议你可以使用get_tensor_by_name()

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8, name="a")
         b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
     return g

graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)

我个人经常使用选项2,它非常简单,不需要玩名字.当图表很大并且将会存在很长时间时我使用选项3,因为集合与模型一起保存,并且它是记录真正重要内容的快速方法.我没有真正使用选项1,因为我更喜欢实际引用对象(不确定原因).当您使用由其他人构建的图表时,选项4非常有用,并且他们没有直接引用张量.

希望这可以帮助!


poh*_*ohe 0

我也在寻找更好的方法,所以我的答案可能不是最好的。尽管如此,如果您给出a一个b名称,例如

a = tf.placeholder(tf.int8, name='a')
b = tf.add(a, tf.constant(1, dtype=tf.int8), name='b')
Run Code Online (Sandbox Code Playgroud)

然后你可以做

graph = build_graph()

a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')

with tf.Session(graph=graph) as sess:
    feed = {a: 3}      
    print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)

ps 命名ab没有必要。只是以后参考起来比较方便。另外,如果您找到了更好的解决方案,也请分享。