tf.train.MonitoredTrainingSession and a reinitializable iterator from a Dataset

Vik*_*man 12 tensorflow tensorflow-datasets

It seems that MonitoredTrainingSession does some operations (logging?) before the first call to .run(..), which means that when I do this:

import tensorflow as tf  # TF 1.x (tf.contrib.data)

train_data = reader.traindata()  # returns a tf.contrib.data.Dataset
it = tf.contrib.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
init_train = it.make_initializer(train_data)
ne = it.get_next()
ts = tf.train.MonitoredTrainingSession(checkpoint_dir=save_path)

# ... no calls to ts.run ...

ts.run(init_train)

I get this error:

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element

So it seems that MonitoredTrainingSession runs something before the op I pass to it, which makes it impossible to use it together with a reinitializable iterator from a Dataset.

I'm sure I'm missing something, and I'd love to hear what :-)

Mic*_*n G 8

It looks like there is no direct solution in TensorFlow yet. And yes, it is strange that MonitoredTrainingSession doesn't fully support the Dataset API.

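The reason is that the monitored session skips running init_op when it restores from a checkpoint. The iterator initializer should therefore be treated like a local variable initializer and run in local_init_op, which is executed on every session start, including after a restore.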
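To illustrate the failure mode, here is a minimal sketch, assuming TensorFlow 2.x with the tf.compat.v1 API (tf.contrib.data from the question no longer exists in 2.x). The Dataset.range pipeline stands in for reader.traindata(), and the helper run_with_init_op, its toy train_op and the StopAtStepHook used to end the loop are hypothetical, not part of the original answer. The iterator initializer is placed in Scaffold.init_op: a first session on an empty checkpoint_dir trains normally, while a second session restores that checkpoint, skips init_op, and its first sess.run should fail with a FailedPreconditionError like the one in the question.

```python
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # MonitoredTrainingSession works in graph mode


def run_with_init_op(checkpoint_dir, last_step):
    """Hypothetical helper: trains until global_step reaches last_step, with
    the iterator initializer placed in Scaffold.init_op. Returns steps run."""
    with tf.Graph().as_default():
        # Toy dataset standing in for reader.traindata() (assumption).
        train_data = tf.data.Dataset.range(100).batch(4).repeat()
        it = tf1.data.Iterator.from_structure(
            tf1.data.get_output_types(train_data),
            tf1.data.get_output_shapes(train_data))
        init_train = it.make_initializer(train_data)
        batch = it.get_next()

        global_step = tf1.train.get_or_create_global_step()
        # Hypothetical train_op: consume one batch and bump the global step.
        train_op = tf.group(tf.reduce_sum(batch), global_step.assign_add(1))

        # init_op only runs when no checkpoint was restored.
        scaffold = tf1.train.Scaffold(
            init_op=tf.group(tf1.global_variables_initializer(), init_train))
        hooks = [tf1.train.StopAtStepHook(last_step=last_step)]
        steps = 0
        with tf1.train.MonitoredTrainingSession(
                scaffold=scaffold, checkpoint_dir=checkpoint_dir,
                hooks=hooks) as sess:
            while not sess.should_stop():
                sess.run(train_op)  # GetNext fails here if init_op was skipped
                steps += 1
        return steps
```

Calling run_with_init_op(d, last_step=2) on a fresh temporary directory d runs two steps; calling run_with_init_op(d, last_step=4) on the same directory then takes the restore path and raises.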

The current workaround is the one suggested in this issue: https://github.com/tensorflow/tensorflow/issues/12859

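import tensorflow as tf  # TF 1.x
# init_train comes from the question; train_op and checkpoint_dir are your own.
# local_init_op runs on every session start, including after a checkpoint restore.
# The default local_init_op also runs tf.tables_initializer(), so keep it here.
scaffold = tf.train.Scaffold(local_init_op=tf.group(tf.local_variables_initializer(),
                                                    tf.tables_initializer(),
                                                    init_train))
with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                       checkpoint_dir=checkpoint_dir) as sess:
    while not sess.should_stop():
        sess.run(train_op)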
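A self-contained sketch of this workaround, assuming TensorFlow 2.x with the tf.compat.v1 API. The toy Dataset.range pipeline (standing in for reader.traindata()), the helper run_with_local_init_op, its train_op and the StopAtStepHook that ends the loop are illustrative assumptions, not part of the original answer. init_train goes into Scaffold.local_init_op together with the local variable and table initializers, so it runs both on a fresh start and after restoring a checkpoint.

```python
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # MonitoredTrainingSession works in graph mode


def run_with_local_init_op(checkpoint_dir, last_step):
    """Hypothetical helper: trains until global_step reaches last_step, with
    the iterator initializer in Scaffold.local_init_op. Returns batches seen."""
    with tf.Graph().as_default():
        # Toy dataset standing in for reader.traindata() (assumption).
        train_data = tf.data.Dataset.range(100).batch(4).repeat()
        it = tf1.data.Iterator.from_structure(
            tf1.data.get_output_types(train_data),
            tf1.data.get_output_shapes(train_data))
        init_train = it.make_initializer(train_data)
        batch = it.get_next()

        global_step = tf1.train.get_or_create_global_step()
        # Hypothetical train_op: consume one batch and bump the global step.
        train_op = tf.group(tf.reduce_sum(batch), global_step.assign_add(1))

        # local_init_op runs on every session start, fresh or restored.
        scaffold = tf1.train.Scaffold(local_init_op=tf.group(
            tf1.local_variables_initializer(), tf1.tables_initializer(),
            init_train))
        hooks = [tf1.train.StopAtStepHook(last_step=last_step)]
        seen = []
        with tf1.train.MonitoredTrainingSession(
                scaffold=scaffold, checkpoint_dir=checkpoint_dir,
                hooks=hooks) as sess:
            while not sess.should_stop():
                # One GetNext per run: train_op and the fetched batch share it.
                _, values = sess.run([train_op, batch])
                seen.append(values.tolist())
        return seen
```

Calling it twice on the same temporary directory, first with last_step=3 and then with last_step=6, should yield the batches [0, 1, 2, 3], [4, 5, 6, 7] and [8, 9, 10, 11] both times. The second session restores global_step from the checkpoint and skips init_op, but local_init_op still initializes the iterator. Note that the dataset restarts from the beginning: in this setup the iterator's position is not part of the checkpoint.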