为什么在Tensorflow中使用和不使用上下文管理器定义tf.Session会导致不同的行为?

des*_*esa 5 python tensorflow

我注意到使用和不使用上下文管理器定义会话时会有所不同.这里有一个例子:

上下文管理器:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.Variable(0)
    tf.summary.scalar("x", x)

with tf.Session(graph=graph) as sess:
    summaries = tf.summary.merge_all()
    print("Operations:", sess.graph.get_operations())
    print("\nSummaries:", summaries)
Run Code Online (Sandbox Code Playgroud)

结果是:

Operations: [<tf.Operation 'Variable/initial_value' type=Const>, <tf.Operation 'Variable' type=VariableV2>, <tf.Operation 'Variable/Assign' type=Assign>, <tf.Operation 'Variable/read' type=Identity>, <tf.Operation 'x/tags' type=Const>, <tf.Operation 'x' type=ScalarSummary>, <tf.Operation 'Merge/MergeSummary' type=MergeSummary>]

Summaries: Tensor("Merge/MergeSummary:0", shape=(), dtype=string)
Run Code Online (Sandbox Code Playgroud)

没有上下文管理器

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.Variable(0)
    tf.summary.scalar("x", x)

sess = tf.Session(graph=graph)
summaries = tf.summary.merge_all()
print("Operations:", sess.graph.get_operations())
print("Summaries:", summaries)
sess.close()
Run Code Online (Sandbox Code Playgroud)

结果是:

Operations: [<tf.Operation 'Variable/initial_value' type=Const>, <tf.Operation 'Variable' type=VariableV2>, <tf.Operation 'Variable/Assign' type=Assign>, <tf.Operation 'Variable/read' type=Identity>, <tf.Operation 'x/tags' type=Const>, <tf.Operation 'x' type=ScalarSummary>]
Summaries: None
Run Code Online (Sandbox Code Playgroud)

为什么tf.summary.merge_all()找不到摘要?

Den*_*ers 5

你可以在tf.summary.merge_all() 这里找到实现.它通过调用此函数来工作,该函数从返回的图形中获取集合get_default_graph().该功能的文档如下:

"""Returns the default graph for the current thread.

The returned graph will be the innermost graph on which a
`Graph.as_default()` context has been entered, or a global default
graph if none has been explicitly created.

NOTE: The default graph is a property of the current thread. If you
create a new thread, and wish to use the default graph in that
thread, you must explicitly add a `with g.as_default():` in that
thread's function.

Returns:
    The default `Graph` being used in the current thread.
"""
Run Code Online (Sandbox Code Playgroud)

因此,在没有会话上下文管理器的代码中,问题不一定是您不在会话中; 问题是带有摘要的图形不是默认图形,并且您没有使用该图形输入上下文(如会话).

有一些不同的方法可以在不使用with tf.Session(graph=graph) as sess:上下文管理器的情况下"解决"这个问题:


一种选择是将摘要合并在一起,同时仍然具有graph默认图形:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.Variable(0)
    tf.summary.scalar("x", x)
    summaries = tf.summary.merge_all()

with tf.Session(graph=graph) as sess:
    print("Operations:", sess.graph.get_operations())
    print("\nSummaries:", summaries)
Run Code Online (Sandbox Code Playgroud)

另一种选择是__enter__()在合并摘要之前显式地进行会话(这与with tf.Session(graph=graph) as sess:语句中python内部发生的内容非常相似):

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.Variable(0)
    tf.summary.scalar("x", x)

sess = tf.Session(graph=graph)
sess.__enter__()
summaries = tf.summary.merge_all()
print("Operations:", sess.graph.get_operations())
print("Summaries:", summaries)
sess.close()
Run Code Online (Sandbox Code Playgroud)