Ale*_*lex 6 python structure tensorflow
我正在学习Tensorflow,并且正在尝试正确构建我的代码.我(或多或少)知道如何构建裸图或类方法图,但我想弄清楚如何最好地构造代码.我试过这个简单的例子:
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
return g
graph = build_graph()
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)
应该打印出来4.然而,当我这样做时,我得到错误:
Cannot interpret feed_dict key as Tensor: Tensor
Tensor("Placeholder:0", dtype=int8) is not an element of this graph.
Run Code Online (Sandbox Code Playgroud)
我很确定这是因为函数内的占位符build_graph是私有的,但不应该with tf.Session(graph=graph)照顾它吗?在这种情况下使用feed dict有更好的方法吗?
Min*_*ark 16
有几种选择.
选项1:只传递张量的名称而不是张量本身.
with tf.Session(graph=graph) as sess:
feed = {"Placeholder:0": 3}
print(sess.run("Add:0", feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)
在这种情况下,最好给节点赋予有意义的名称,而不是使用上面的默认名称:
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8, name="a")
b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
return g
graph = build_graph()
with tf.Session(graph=graph) as sess:
feed = {"a:0": 3}
print(sess.run("b:0", feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)
回想一下,名为的操作的输出"foo"是名为"foo:0",, "foo:1"等等的张量.大多数操作只有一个输出.
选项2:使您的build_graph()函数返回所有重要节点.
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
return g, a, b
graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)
选项3:向集合添加重要节点
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
for node in (a, b):
g.add_to_collection("important_stuff", node)
return g
graph = build_graph()
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)
选项4:正如@pohe所建议你可以使用get_tensor_by_name()
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8, name="a")
b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
return g
graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)
我个人经常使用选项2,它非常简单,不需要玩名字.当图表很大并且将会存在很长时间时我使用选项3,因为集合与模型一起保存,并且它是记录真正重要内容的快速方法.我没有真正使用选项1,因为我更喜欢实际引用对象(不确定原因).当您使用由其他人构建的图表时,选项4非常有用,并且他们没有直接引用张量.
希望这可以帮助!
我也在寻找更好的方法,所以我的答案可能不是最好的。尽管如此,如果您给出a一个b名称,例如
a = tf.placeholder(tf.int8, name='a')
b = tf.add(a, tf.constant(1, dtype=tf.int8), name='b')
Run Code Online (Sandbox Code Playgroud)
然后你可以做
graph = build_graph()
a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
Run Code Online (Sandbox Code Playgroud)
ps 命名a并b没有必要。只是以后参考起来比较方便。另外,如果您找到了更好的解决方案,也请分享。
| 归档时间: |
|
| 查看次数: |
3552 次 |
| 最近记录: |