如何在Tensorflow中恢复检查点时获取global_step?

Dan*_*ter 18 tensorflow

我正在保存我的会话状态:

self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)
Run Code Online (Sandbox Code Playgroud)

当我稍后恢复时,我想获取我从中恢复的检查点的global_step的值.这是为了从中设置一些超参数.

执行此操作的hacky方法是运行并解析检查点目录中的文件名.但是,必须有一个更好的,内置的方式来做到这一点?

Yar*_*tov 25

一般模式是有一个global_step变量来跟踪步骤

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
Run Code Online (Sandbox Code Playgroud)

然后你可以保存

saver.save(sess, save_path, global_step=global_step)
Run Code Online (Sandbox Code Playgroud)

当你恢复,价值global_step恢复,以及

  • 这不起作用,每次我恢复训练时,global_step变量都会重置为0 (3认同)

Law*_* Du 6

这有点骇人听闻,但其他答案对我根本没有用

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 

#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
Run Code Online (Sandbox Code Playgroud)

更新9/2017

我不确定这是否由于更新而开始起作用,但是以下方法似乎对于使global_step正确更新和加载有效:

创建两个操作。一个用于保存global_step,另一个用于对其进行递增:

    global_step = tf.Variable(0, trainable=False, name='global_step')
    increment_global_step = tf.assign_add(global_step,1,
                                            name = 'increment_global_step')
Run Code Online (Sandbox Code Playgroud)

现在,在您的训练循环中,每次您运行训练操作时都要运行增量操作。

sess.run([train_op,increment_global_step],feed_dict=feed_dict)
Run Code Online (Sandbox Code Playgroud)

如果您想随时以整数形式检索全局步长值,只需在加载模型后使用以下命令:

sess.run(global_step)
Run Code Online (Sandbox Code Playgroud)

这对于创建文件名或计算当前时期是有用的,而无需第二个tensorflow变量来保存该值。例如,计算加载时的当前时间将类似于:

loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)
Run Code Online (Sandbox Code Playgroud)