如何让tf.data.Dataset返回一个调用中的所有元素?

Mil*_*lad 8 tensorflow tensorflow-datasets

是否有一种简单的方法可以获得整个元素集tf.data.Dataset?即我想将数据集的批量大小设置为我的数据集的大小,而不是特别传递元素的数量.这对于验证数据集很有用,我想一次性测量整个数据集的准确性.我很惊讶没有办法获得一个大小tf.data.Dataset

Abh*_*k S 6

在张量流2.0

您可以使用as_numpy_iterator枚举数据集

for element in Xtrain.as_numpy_iterator(): 
  print(element) 
Run Code Online (Sandbox Code Playgroud)


mus*_*rat 3

简而言之,没有一个好的方法来获取尺寸/长度;tf.data.Dataset是为数据管道而构建的,因此具有迭代器结构(根据我的理解并根据我对数据集操作代码的阅读。来自程序员指南

Atf.data.Iterator提供了从数据集中提取元素的主要方法。返回的操作Iterator.get_next()在执行时会生成数据集的下一个元素,并且通常充当输入管道代码和模型之间的接口。

而且,就其本质而言,迭代器没有方便的大小/长度概念;请参阅此处:在 Python 中获取迭代器中的元素数量

但更一般地说,为什么会出现这个问题?如果您调用batch,您也会得到一个tf.data.Dataset,因此无论您在批处理上运行什么,您都应该能够在整个数据集上运行;它将迭代所有元素并计算验证准确性。换句话说,我认为您实际上不需要尺寸/长度来完成您想做的事情。

  • 我的代码接受训练和验证 tfrecords 文件,并将它们转换为两个 tf.Datasets,其中一个迭代器可以初始化为两个数据集(类似于 TF 中的 [examples](https://www.tensorflow.org/programmers_guide/datasets)文档)。训练数据的纪元数和批量大小在我的控制范围内,我可以轻松地在训练数据集上应用 .batch() 和 .repeat() 方法。但是,对于验证数据,我想创建一个包含所有样本的批次,但我不一定知道 tfrecord 文件中有多少个样本。 (4认同)