mac*_*aut 3 tensorflow tensorflow-datasets
在训练期间进行检查点时(以防崩溃等),我保存图表和参数,但不清楚如何对tf.data用于输入的新对象执行相同的操作。
有没有一种直接的方法来检查这些,以便我可以继续当前的纪元,或恢复洗牌状态(也许从种子?)
该tf.contrib.data.make_saveable_from_iterator()函数接受一个tf.data.Iterator对象并返回一个“可保存对象”,可以使用tf.train.Saver. 它保存迭代器的整个状态,包括任何打乱的数据。
以下示例代码展示了如何将简单的迭代器添加到用于变量的同一检查点:
ds = tf.data.Dataset.range(10)
iterator = ds.make_initializable_iterator()
# [Build the training graph, using `iterator.get_next()` as the input.]
# Build the iterator SaveableObject.
saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
# Add the SaveableObject to the SAVEABLE_OBJECTS collection so
# it will be saved automatically using a Saver.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
# Create a saver that saves all objects in the `tf.GraphKeys.SAVEABLE_OBJECTS`
# collection.
saver = tf.train.Saver()
with tf.Session() as sess:
while continue_training:
# [Perform training.]
if should_save_checkpoint:
saver.save(sess, ...)
Run Code Online (Sandbox Code Playgroud)
请注意,迭代器检查点支持当前(自 TensorFlow 1.8 起)处于实验状态,因此检查点格式可能会从一个版本更改为下一个版本。
| 归档时间: |
|
| 查看次数: |
1471 次 |
| 最近记录: |