tf.data.Dataset:如何获取数据集大小(历元中的元素数)?

nes*_*uno 12 python-3.x tensorflow tensorflow-datasets

假设我以这种方式定义了一个数据集:

filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
Run Code Online (Sandbox Code Playgroud)

如何获取数据集中的元素数量(因此,构成一个纪元的单个元素的数量)?

我知道tf.data.Dataset已经知道数据集的维度,因为该repeat()方法允许在指定的时期内重复输入管道。因此,它必须是获取此信息的一种方法。

Jac*_*sen 20

看看这里:https : //github.com/tensorflow/tensorflow/issues/26966

它不适用于 TFRecord 数据集,但适用于其他类型。

特尔;博士:

num_elements = tf.data.experimental.cardinality(dataset).numpy()


iru*_*yak 10

更新:

使用tf.data.experimental.cardinality(dataset)- 请参阅此处


在 tensorflow 数据集的情况下,您可以使用_, info = tfds.load(with_info=True). 然后你可以打电话info.splits['train'].num_examples。但即使在这种情况下,如果您定义自己的拆分,它也无法正常工作。

因此,您可以计算文件数或迭代数据集(如其他答案中所述):

num_training_examples = 0
num_validation_examples = 0

for example in training_set:
    num_training_examples += 1

for example in validation_set:
    num_validation_examples += 1
Run Code Online (Sandbox Code Playgroud)


小智 8

不幸的是,我认为 TF 中还没有这样的功能。但是,使用 TF 2.0 和 Eager Execution,您可以遍历数据集:

num_elements = 0
for element in dataset:
    num_elements += 1
Run Code Online (Sandbox Code Playgroud)

这是我能想到的最有效的存储方式

这真的感觉像是很久以前就应该添加的功能。手指交叉,他们在以后的版本中添加了这个长度功能。

  • 或者,在 TF 2.0 中添加更简洁的方法: count = dataset.reduce(0, lambda x, _: x + 1)` (7认同)

Tim*_*lin 8

从 TensorFlow (>=2.3) 开始,可以使用:

dataset.cardinality().numpy()
Run Code Online (Sandbox Code Playgroud)

请注意,该.cardinality()方法已集成到主包中(在包中之前experimental)。

请注意,在应用filter()操作时,此操作可以返回 -2。

  • `train_ds.cardinality().numpy()` 给了我 `-2`!!! (2认同)
  • 它给你 -2 因为你在代码中的某个地方使用了 .filter() (2认同)

P-G*_*-Gn 7

tf.data.Dataset.list_files创建一个称为张量MatchingFiles:0(如果适用,带有适当的前缀)。

你可以评估

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
Run Code Online (Sandbox Code Playgroud)

获取文件数。

当然,这仅适用于简单的情况,特别是如果每​​个图像只有一个样本(或已知数量的样本)。

在更复杂的情况下,例如,当您不知道每个文件中的样本数时,您只能在 epoch 结束时观察样本数。

为此,您可以查看由您的Dataset. repeat()创建一个名为 的成员_count,它计算时代的数量。通过在迭代期间观察它,您可以发现它何时发生变化并从那里计算您的数据集大小。

这个计数器可能被埋在Dataset依次调用成员函数时创建的s的层次结构中,所以我们要这样挖出来。

d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround 
RepeatDataset = type(tf.data.Dataset().repeat())
try:
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  warnings.warn('no epoch counter found')
  epoch_counter = None
else:
  epoch_counter = d._count
Run Code Online (Sandbox Code Playgroud)

请注意,使用此技术,数据集大小的计算并不准确,因为在此期间epoch_counter递增的批次通常混合来自两个连续 epoch 的样本。所以这个计算精确到你的批次长度。


mar*_*mus 7

len(list(dataset)) 可以在热切的模式下工作,尽管显然这不是一个好的通用解决方案。

  • 它违背了它作为迭代器的目的。调用 list() 一次性运行整个过程。它适用于较小量的数据,但对于较大的数据集可能会占用太多资源。 (11认同)
  • @yrekkehs 绝对,这就是为什么它不是一个好的通用解决方案。但它有效。 (2认同)

小智 7

我看到了很多获取样本数量的方法,但实际上你可以通过以下方式轻松做到keras

len(dataset) * BATCH_SIZE
Run Code Online (Sandbox Code Playgroud)


小智 5

您可以将其用于 TF2 中的 TFRecords:

ds = tf.data.TFRecordDataset(dataset_filenames)
ds_size = sum(1 for _ in ds)
Run Code Online (Sandbox Code Playgroud)


小智 5

这对我有用:

lengt_dataset = dataset.reduce(0, lambda x,_: x+1).numpy()
Run Code Online (Sandbox Code Playgroud)

它遍历您的数据集并增加 var x,它作为数据集的长度返回。