有没有办法在 tf >= 2.4 的 GPU 上运行 tf.data API

Cha*_*G R 5 python tensorflow

令我惊讶的是,我找不到在 GPU 上运行 tf.data API 的简洁方法。据我了解,数据管道可以在 CPU 上运行,以便它可以并行发生(通过预取),从而允许 GPU 运行实际模型并对其进行训练。

然而,我的预处理是极其并行和计算密集型的。虽然从技术上讲我可以将预处理编写为模型中的第一层,但我真的不希望这样做以防止训练数据泄漏到我的模型中。

对此任何指示表示赞赏。我发现的最接近的是https://towardsdatascience.com/overcoming-data-preprocessing-bottlenecks-with-tensorflow-data-service-nvidia-dali-and-other-d6321917f851,其中涉及使用 nvidia DALI 框架。

以下是一些关键点:

  • 我已经尝试过使用 强制设备放置tf.device('...')
  • 我不想将数据预先导入到设备中,而是在 GPU 上运行整个数据管道。
  • 最好,如果我的计算更多,我想将我的数据集保存为tfrecords,以便我可以直接加载它。现在可以完成此操作tf.data.experimental.save,但它再次使用 CPU!