如何从tf.train.string_input_producer获取epoch num信息

all*_*len 6 tensorflow

如果使用string_input_producer读取文件,就像

filename_queue = tf.train.string_input_producer(
  files, 
  num_epochs=num_epochs,
  shuffle=shuffle)
Run Code Online (Sandbox Code Playgroud)

如何在训练期间获得epoch num信息(我想在训练期间显示此信息)我在下面尝试过

run 
tf.get_default_graph().get_tensor_by_name('input_train/input_producer/limit_epochs/epochs:0')
Run Code Online (Sandbox Code Playgroud)

将始终与限制纪元数相同.

run
tf.get_default_graph().get_tensor_by_name('input_train/input_producer/limit_epochs/CountUpTo:0')
Run Code Online (Sandbox Code Playgroud)

每次都会加1 ..

两者都无法在训练期间获得正确的纪元数.

另一件事是,如果从现有模型重新训练,我可以获得已经训练过的纪元数据吗?

chi*_*ger 1

我认为这里正确的方法是定义一个global_step传递给优化器的变量(或者您可以手动增加它)。

TensorFlow Mechanics 101教程提供了一个示例:

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
Run Code Online (Sandbox Code Playgroud)

现在global_step每次train_op运行都会增加。由于您知道数据集的大小和批量大小,因此您会知道当前处于哪个纪元。

当您使用 a 保存模型时tf.train.Saver()global_step变量也会被保存。当您恢复模型时,您只需调用即可global_step.eval()恢复您上次停止的步骤值。

我希望这有帮助!