TensorFlow 数据集的函数 cache() 和 prefetch() 有何作用?

rtr*_*trt 8 caching dataset prefetch tensorflow tensorflow-datasets

我正在学习 TensorFlow 的图像分割教程。其中有以下几行:

train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
Run Code Online (Sandbox Code Playgroud)
  1. 该函数有什么cache()作用?官方文档相当晦涩且自引用:

缓存此数据集中的元素。

  1. 该函数有什么prefetch()作用?官方文档又相当晦涩难懂:

创建一个从该数据集中预取元素的数据集。

小智 16

转换tf.data.Dataset.cache可以在内存或本地存储中缓存数据集。这将避免在每个时期执行一些操作(例如文件打开和数据读取)。下一个纪元将重用缓存转换缓存的数据。

您可以在此处找到有关cache张量流的更多信息。

Prefetch重叠训练步骤的预处理和模型执行。当模型执行训练步骤 s 时,输入管道正在读取步骤 s+1 的数据。这样做可以减少训练的最大步骤时间(而不是总和)以及提取数据所需的时间。

您可以在此处找到有关prefetch张量流的更多信息。

希望这能回答您的问题。快乐学习。