受监控的培训课程如何运作?

Jas*_*son 33 python tensorflow

我试图理解使用tf.Session和之间的差异tf.train.MonitoredTrainingSession,以及我可能更喜欢一个而不是另一个.似乎当我使用后者时,我可以避免许多"杂务",例如初始化变量,启动队列运行器或为汇总操作设置文件编写器.另一方面,通过受监控的培训会话,我无法指定我想明确使用的计算图.所有这些对我来说似乎都很神秘.这些类是如何创建的,我不理解这些背后的哲学吗?

pfm*_*pfm 31

我无法对如何创建这些类提供一些见解,但我认为这些内容与您如何使用它们有关.

tf.Session是python TensorFlow API中的低级对象,而正如您所说,tf.train.MonitoredTrainingSession它具有许多方便的功能,在大多数常见情况下尤其有用.

在描述一些好处之前tf.train.MonitoredTrainingSession,让我回答一下有关会话使用的图表的问题.您可以指定tf.Graph使用的MonitoredTrainingSession使用情况管理器with your_graph.as_default():

from __future__ import print_function
import tensorflow as tf

def example():
    g1 = tf.Graph()
    with g1.as_default():
        # Define operations and tensors in `g`.
        c1 = tf.constant(42)
        assert c1.graph is g1

    g2 = tf.Graph()
    with g2.as_default():
        # Define operations and tensors in `g`.
        c2 = tf.constant(3.14)
        assert c2.graph is g2

    # MonitoredTrainingSession example
    with g1.as_default():
        with tf.train.MonitoredTrainingSession() as sess:
            print(c1.eval(session=sess))
            # Next line raises
            # ValueError: Cannot use the given session to evaluate tensor:
            # the tensor's graph is different from the session's graph.
            try:
                print(c2.eval(session=sess))
            except ValueError as e:
                print(e)

    # Session example
    with tf.Session(graph=g2) as sess:
        print(c2.eval(session=sess))
        # Next line raises
        # ValueError: Cannot use the given session to evaluate tensor:
        # the tensor's graph is different from the session's graph.
        try:
            print(c1.eval(session=sess))
        except ValueError as e:
            print(e)

if __name__ == '__main__':
    example()
Run Code Online (Sandbox Code Playgroud)

所以,正如你所说,使用的好处MonitoredTrainingSession是,这个对象需要处理

  • 初始化变量,
  • 开始排队跑步者以及
  • 设置文件编写器,

但它也有使代码易于分发的好处,因为它的工作方式也不同,这取决于您是否将运行进程指定为主进程.

例如,您可以运行以下内容:

def run_my_model(train_op, session_args):
    with tf.train.MonitoredTrainingSession(**session_args) as sess:
        sess.run(train_op)
Run Code Online (Sandbox Code Playgroud)

您将以非分布式方式调用:

run_my_model(train_op, {})`
Run Code Online (Sandbox Code Playgroud)

或以分布式方式(有关输入的更多信息,请参阅分布式文档):

run_my_model(train_op, {"master": server.target,
                        "is_chief": (FLAGS.task_index == 0)})
Run Code Online (Sandbox Code Playgroud)

另一方面,使用原始tf.Session对象的好处是,您没有额外的好处tf.train.MonitoredTrainingSession,如果您不打算使用它们或者想要获得更多控制(例如,如何启动队列).

编辑(根据评论): 对于操作初始化,你必须做类似的事情(参见官方文档:

# Define your graph and your ops
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_p)
    sess.run(your_graph_ops,...)
Run Code Online (Sandbox Code Playgroud)

对于QueueRunner,我会向您推荐官方文档,您将在其中找到更完整的示例.

EDIT2:

了解如何tf.train.MonitoredTrainingSession运作的主要概念是_WrappedSession类:

此包装器用作各种会话包装器的基类,这些包装器提供监视,协调和恢复等附加功能.

这样的tf.train.MonitoredTrainingSession作品(1.1版本):

  • 它首先检查它是否是主管或工人(参见词汇问题的分布式文档).
  • 它开始提供的钩子(例如,在这个阶段StopAtStepHook只检索global_step张量).
  • 它创建一个会话,它将一个Chief(或Worker会话)包裹成一个_HookedSession包裹成一个_CoordinatedSession包裹成一个_RecoverableSession.
    Chief/ Worker会话负责运行由提供初始化OPS的Scaffold.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
    not specified a default one is created. It's used to finalize the graph.
    
    Run Code Online (Sandbox Code Playgroud)
  • chief会议还负责所有关卡部分:使用从检查点恢复,例如SaverScaffold.
  • _HookedSession基本上是有来装饰run方法:调用_call_hook_before_runafter_run方法时相关.
  • 在创建时,_CoordinatedSession构建a Coordinator启动队列运行程序并负责关闭它们.
  • _RecoverableSession会确保有重试的情况下tf.errors.AbortedError.

总之,tf.train.MonitoredTrainingSession避免了很多锅炉板代码,同时可以用钩子机构轻松扩展.