我正在保存我的会话状态:
self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)
Run Code Online (Sandbox Code Playgroud)
当我稍后恢复时,我想获取我从中恢复的检查点的global_step的值.这是为了从中设置一些超参数.
执行此操作的hacky方法是运行并解析检查点目录中的文件名.但是,必须有一个更好的,内置的方式来做到这一点?
Yar*_*tov 25
一般模式是有一个global_step变量来跟踪步骤
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
Run Code Online (Sandbox Code Playgroud)
然后你可以保存
saver.save(sess, save_path, global_step=global_step)
Run Code Online (Sandbox Code Playgroud)
当你恢复,价值global_step恢复,以及
这有点骇人听闻,但其他答案对我根本没有用
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
Run Code Online (Sandbox Code Playgroud)
更新9/2017
我不确定这是否由于更新而开始起作用,但是以下方法似乎对于使global_step正确更新和加载有效:
创建两个操作。一个用于保存global_step,另一个用于对其进行递增:
global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step,1,
name = 'increment_global_step')
Run Code Online (Sandbox Code Playgroud)
现在,在您的训练循环中,每次您运行训练操作时都要运行增量操作。
sess.run([train_op,increment_global_step],feed_dict=feed_dict)
Run Code Online (Sandbox Code Playgroud)
如果您想随时以整数形式检索全局步长值,只需在加载模型后使用以下命令:
sess.run(global_step)
Run Code Online (Sandbox Code Playgroud)
这对于创建文件名或计算当前时期是有用的,而无需第二个tensorflow变量来保存该值。例如,计算加载时的当前时间将类似于:
loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
21373 次 |
| 最近记录: |