TensorFlow Estimator - warm_start_from and model_dir

mtn*_*gld 9 tensorflow tensorflow-estimator

When using tf.estimator with both warm_start_from and model_dir, and both the warm_start_from directory and the model_dir directory contain valid checkpoints, which checkpoint is actually restored?

To give some context, my estimator code looks like this:

est = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    warm_start_from=warm_start_dir)

for epoch in range(num_epochs):
    est.train(input_fn=train_input_fn)
    est.evaluate(input_fn=eval_input_fn)

(The input functions use one-shot iterators.)

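So during the first iteration, when model_dir is empty, I want the warm-start checkpoint to be loaded, but in the following epochs I want the intermediate fine-tuned checkpoint from the previous iteration in model_dir to be loaded. But at least judging from the logs, it looks like warm_start_dir is still being loaded.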
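The policy described above (warm-start only while model_dir is empty, otherwise resume from model_dir) can be written as a small decision rule. This is a minimal sketch, not TensorFlow's own logic: `find_latest_checkpoint` and `choose_init_source` are hypothetical helpers, and `find_latest_checkpoint` is a simplified stand-in for `tf.train.latest_checkpoint` that assumes checkpoints are named like `model.ckpt-<step>.index` (the real function reads the `checkpoint` state file instead).

```python
import glob
import os
import re
from typing import Optional, Tuple


def find_latest_checkpoint(directory: str) -> Optional[str]:
    """Hypothetical, simplified stand-in for tf.train.latest_checkpoint.

    Scans for files named like 'model.ckpt-<step>.index' and returns the
    checkpoint prefix with the highest step, or None if there are none.
    """
    best_step, best_prefix = -1, None
    for index_file in glob.glob(os.path.join(directory, "*.ckpt-*.index")):
        prefix = index_file[: -len(".index")]
        match = re.search(r"-(\d+)$", prefix)
        if match and int(match.group(1)) > best_step:
            best_step, best_prefix = int(match.group(1)), prefix
    return best_prefix


def choose_init_source(model_dir: str, warm_start_dir: str) -> Tuple[str, Optional[str]]:
    """Return (source, checkpoint) following the policy from the question."""
    resume = find_latest_checkpoint(model_dir)
    if resume is not None:
        # Later epochs: continue from the fine-tuned checkpoint in model_dir
        return ("resume", resume)
    warm = find_latest_checkpoint(warm_start_dir)
    if warm is not None:
        # First epoch: model_dir is still empty, so warm-start
        return ("warm_start", warm)
    return ("scratch", None)
```

The important design point is the order of the checks: model_dir is consulted first, so the warm-start directory only matters before any checkpoint has been written there.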

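I could probably recreate the estimator for the next iteration, but I'm wondering whether this is supposed to be built into the estimator somehow.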
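The workaround mentioned above, recreating the estimator between iterations, could look roughly like this. It is a sketch under assumptions: `train_with_rebuilt_estimator` is a hypothetical name, and the estimator factory and the checkpoint check are injected so the loop can be exercised without TensorFlow. With TensorFlow you might pass `make_estimator=functools.partial(tf.estimator.Estimator, model_fn=model_fn)` and `has_checkpoint=lambda d: tf.train.latest_checkpoint(d) is not None`.

```python
from typing import Any, Callable, List, Optional


def train_with_rebuilt_estimator(
    num_epochs: int,
    model_dir: str,
    warm_start_dir: str,
    make_estimator: Callable[..., Any],
    has_checkpoint: Callable[[str], bool],
    train_input_fn: Any,
    eval_input_fn: Any,
) -> List[Optional[str]]:
    """Rebuild the estimator each epoch, warm-starting only while model_dir is empty.

    Returns the warm_start_from value that was used for each epoch.
    """
    used: List[Optional[str]] = []
    for _ in range(num_epochs):
        # Only pass warm_start_from while model_dir holds no checkpoint yet
        warm_start_from = None if has_checkpoint(model_dir) else warm_start_dir
        est = make_estimator(model_dir=model_dir, warm_start_from=warm_start_from)
        est.train(input_fn=train_input_fn)
        est.evaluate(input_fn=eval_input_fn)
        used.append(warm_start_from)
    return used
```

Because the estimators built after the first epoch get `warm_start_from=None`, they have no warm-start setting at all, and the only checkpoint left for them to start from is the one in model_dir.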

小智 4

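I had a similar problem and solved it by providing an initialization hook that runs when the session starts, together with tf.estimator.train_and_evaluate (though I can't take credit for the whole solution, since I saw something similar used for a different purpose elsewhere):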

import logging

import tensorflow as tf


class InitHook(tf.train.SessionRunHook):
    """Initializes the model from the latest checkpoint in a directory.

    Args:
        checkpoint_dir: directory containing the checkpoint to warm-start from
    """
    def __init__(self, checkpoint_dir):
        self.modelPath = checkpoint_dir
        self.initialized = False

    def begin(self):
        """
        Restore encoder parameters if a pre-trained encoder model is available and we haven't trained previously
        """
        if not self.initialized:
            log = logging.getLogger('tensorflow')
            checkpoint = tf.train.latest_checkpoint(self.modelPath)
            if checkpoint is None:
                log.info('No pre-trained model is available, training from scratch.')
            else:
                log.info('Pre-trained model {0} found in {1} - warmstarting.'.format(checkpoint, self.modelPath))
                # Warm-start the graph's variables from the pre-trained checkpoint
                tf.train.warm_start(checkpoint)
            # Only warm-start once per hook instance
            self.initialized = True

Then, for training:

# `estimator` is the tf.estimator.Estimator built with model_dir;
# N_STEPS is the total number of training steps
initHook = InitHook(checkpoint_dir=warm_start_dir)
trainSpec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=N_STEPS,
    hooks=[initHook]
)
evalSpec = tf.estimator.EvalSpec(
    input_fn=eval_input_fn,
    steps=None,
    name='eval',
    throttle_secs=3600
)
tf.estimator.train_and_evaluate(estimator, trainSpec, evalSpec)

It runs once at the start to initialize the variables from warm_start_dir. Later, once the estimator has written new checkpoints to model_dir, it continues from there.
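The run-once behaviour of `InitHook` comes from its `initialized` flag. The sketch below keeps only that control flow so it can be run without TensorFlow. `OnceOnlyWarmStart` is a hypothetical class, and the injected `latest_checkpoint_fn` and `warm_start_fn` stand in for `tf.train.latest_checkpoint` and `tf.train.warm_start`.

```python
from typing import Callable, Optional


class OnceOnlyWarmStart:
    """TF-free sketch of InitHook's control flow (hypothetical class)."""

    def __init__(self, checkpoint_dir: str,
                 latest_checkpoint_fn: Callable[[str], Optional[str]],
                 warm_start_fn: Callable[[str], None]) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.latest_checkpoint_fn = latest_checkpoint_fn  # stand-in for tf.train.latest_checkpoint
        self.warm_start_fn = warm_start_fn  # stand-in for tf.train.warm_start
        self.initialized = False

    def begin(self) -> None:
        # Same guard as InitHook: warm-start at most once per hook instance
        if self.initialized:
            return
        checkpoint = self.latest_checkpoint_fn(self.checkpoint_dir)
        if checkpoint is not None:
            self.warm_start_fn(checkpoint)
        self.initialized = True
```

If `begin()` is called again on the same hook object, it returns immediately, so the warm-start checkpoint is applied at most once.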