我有一个使用迭代器训练我的网络的模型; 遵循Google现在推荐的新数据集API管道模型.
我读了tfrecord文件,将数据提供给网络,训练得很好,一切顺利,我在训练结束时保存了我的模型,所以我可以在以后运行推理.代码的简化版本如下:
""" Training and saving """
training_dataset = tf.contrib.data.TFRecordDataset(training_record)
training_dataset = training_dataset.map(ds._path_records_parser)
training_dataset = training_dataset.batch(BATCH_SIZE)
with tf.name_scope("iterators"):
training_iterator = Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
next_training_element = training_iterator.get_next()
training_init_op = training_iterator.make_initializer(training_dataset)
def train(num_epochs):
# compute for the number of epochs
for e in range(1, num_epochs+1):
session.run(training_init_op) #initializing iterator here
while True:
try:
images, labels = session.run(next_training_element)
session.run(optimizer, feed_dict={x: images, y_true: labels})
except tf.errors.OutOfRangeError:
saver_name = './saved_models/ucf-model'
print("Finished Training Epoch {}".format(e))
break
""" Restoring """
# restoring the saved model and its variables …
Run Code Online (Sandbox Code Playgroud)