Mic*_*n G 5 tensorflow tensorflow-datasets
Tensorflow程序员指南建议使用可馈送迭代器在训练和验证数据集之间切换,而无需重新初始化迭代器.它主要需要喂食手柄以在它们之间进行选择.
如何与tf.train.MonitoredTrainingSession一起使用它?
以下方法失败并显示"RuntimeError:Graph已完成且无法修改".错误.
with tf.train.MonitoredTrainingSession() as sess:
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
Run Code Online (Sandbox Code Playgroud)
如何同时实现MonitoredTrainingSession的便利性和迭代训练和验证数据集?
我从Tensorflow GitHub问题得到了答案 - https://github.com/tensorflow/tensorflow/issues/12859
解决方案是iterator.string_handle()在创建之前调用MonitoredSession.
import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator
dataset_train = Dataset.range(10)
dataset_val = Dataset.range(90, 100)
iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()
handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(
handle, dataset_train.output_types, dataset_train.output_shapes)
next_batch = iterator.get_next()
with tf.train.MonitoredTrainingSession() as sess:
handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
for step in range(10):
print('train', sess.run(next_batch, feed_dict={handle: handle_train}))
if step % 3 == 0:
print('val', sess.run(next_batch, feed_dict={handle: handle_val}))
Output:
('train', 0)
('val', 90)
('train', 1)
('train', 2)
('val', 91)
('train', 3)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3170 次 |
| 最近记录: |