我有一个输入管道,可以在飞行中生成样本.我使用keras和自定义ImageDataGenerator以及相应的Iterator来获取内存中的样本.假设我的设置中的keras正在使用feed_dict(这个假设对我来说是一个问题)我想通过切换到raw tensorflow + Dataset.from_generator()来加快速度.
在这里,我看到在最近的Tensorflow中生成数据的输入管道的建议解决方案是使用Dataset.from_generator().
问题:
tf.contrib.data.Dataset.from_generator()通过将数据准备与培训重叠,新功能可以加快输入管道的速度.但是,只要有可能,您就可以通过切换到输入管道中的TensorFlow操作来获得最佳性能.
回答您的具体问题:
Keras TensorFlow后端用于tf.placeholder()表示已编译的函数输入,并将feed_dict参数传递给函数.
随着最近的优化tf.py_func()和feed_dict复制开销,我怀疑花费的时间memcpy()将是相同的.但是,您可以更轻松地使用Dataset.from_generator()以Dataset.prefetch()将一个批次的培训与下一批次的预处理重叠.
听起来你可以为预测阶段定义一个单独的迭代器.该tf.estimator.Estimator班确实通过实例化不同的"输入功能"与训练和评估不同的签名,然后建立为每个角色单独的图形类似的东西.
或者,您可以向训练迭代器(对于batch_z值)添加虚拟输出,并使用"可馈送迭代器"在训练和评估迭代器之间切换.
| 归档时间: |
|
| 查看次数: |
794 次 |
| 最近记录: |