Tensorflow Data API - 预取

MPę*_*ski 12 prefetch tensorflow tensorflow-datasets

我正在尝试使用TF的新功能,即Data API,我不确定prefetch的工作原理.在下面的代码中

def dataset_input_fn(...)
    dataset = tf.data.TFRecordDataset(filenames, compression_type="ZLIB")
    dataset = dataset.map(lambda x:parser(...))
    dataset = dataset.map(lambda x,y: image_augmentation(...)
                      , num_parallel_calls=num_threads
                     )

    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)    
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
Run Code Online (Sandbox Code Playgroud)

在我上面的每一行之间有关系dataset=dataset.prefetch(batch_size)吗?或者也许它应该是在output_buffer_size数据集来自何时将要使用的每个操作之后tf.contrib.data

MPę*_*ski 12

在讨论github时, 我发现了mrry的评论:

请注意,在TF 1.4中,将有一个Dataset.prefetch()方法,可以更容易地在管道中的任何位置添加预取,而不仅仅是在map()之后.(您可以通过下载当前的每晚构建来尝试.)

例如,Dataset.prefetch()将启动后台线程以填充有序缓冲区,该缓冲区的作用类似于tf.FIFOQueue,因此下游管道阶段无需阻塞.但是,prefetch()实现要简单得多,因为它不需要支持与tf.FIFOQueue一样多的不同并发操作.

所以它意味着prefetch可以由任何命令放置,它适用于上一个命令.到目前为止,我已经注意到最大的性能提升仅仅是在最后.

在Dataset.map,Dataset.prefetch和Dataset.shuffle中还有一个关于buffer_size含义的讨论,其中mrry解释了有关预取和缓冲区的更多信息.

更新2018/10/01:

从版本1.7.0开始,Dataset API(在contrib中)有一个选项prefetch_to_device.请注意,此转换必须是管道中的最后一个,当TF 2.0到达时contrib将消失.要在多个GPU上进行预取工作,请使用MultiDeviceIterator(例如参见#13610)multi_device_iterator_ops.py.

https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/prefetch_to_device