nes*_*uno 12 python-3.x tensorflow tensorflow-datasets
假设我以这种方式定义了一个数据集:
filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
Run Code Online (Sandbox Code Playgroud)
如何获取数据集中的元素数量(因此,构成一个纪元的单个元素的数量)?
我知道tf.data.Dataset已经知道数据集的维度,因为该repeat()方法允许在指定的时期内重复输入管道。因此,它必须是获取此信息的一种方法。
Jac*_*sen 20
看看这里:https : //github.com/tensorflow/tensorflow/issues/26966
它不适用于 TFRecord 数据集,但适用于其他类型。
特尔;博士:
num_elements = tf.data.experimental.cardinality(dataset).numpy()
iru*_*yak 10
更新:
使用tf.data.experimental.cardinality(dataset)- 请参阅此处。
在 tensorflow 数据集的情况下,您可以使用_, info = tfds.load(with_info=True). 然后你可以打电话info.splits['train'].num_examples。但即使在这种情况下,如果您定义自己的拆分,它也无法正常工作。
因此,您可以计算文件数或迭代数据集(如其他答案中所述):
num_training_examples = 0
num_validation_examples = 0
for example in training_set:
num_training_examples += 1
for example in validation_set:
num_validation_examples += 1
Run Code Online (Sandbox Code Playgroud)
小智 8
不幸的是,我认为 TF 中还没有这样的功能。但是,使用 TF 2.0 和 Eager Execution,您可以遍历数据集:
num_elements = 0
for element in dataset:
num_elements += 1
Run Code Online (Sandbox Code Playgroud)
这是我能想到的最有效的存储方式
这真的感觉像是很久以前就应该添加的功能。手指交叉,他们在以后的版本中添加了这个长度功能。
从 TensorFlow (>=2.3) 开始,可以使用:
dataset.cardinality().numpy()
Run Code Online (Sandbox Code Playgroud)
请注意,该.cardinality()方法已集成到主包中(在包中之前experimental)。
请注意,在应用filter()操作时,此操作可以返回 -2。
tf.data.Dataset.list_files创建一个称为张量MatchingFiles:0(如果适用,带有适当的前缀)。
你可以评估
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
Run Code Online (Sandbox Code Playgroud)
获取文件数。
当然,这仅适用于简单的情况,特别是如果每个图像只有一个样本(或已知数量的样本)。
在更复杂的情况下,例如,当您不知道每个文件中的样本数时,您只能在 epoch 结束时观察样本数。
为此,您可以查看由您的Dataset. repeat()创建一个名为 的成员_count,它计算时代的数量。通过在迭代期间观察它,您可以发现它何时发生变化并从那里计算您的数据集大小。
这个计数器可能被埋在Dataset依次调用成员函数时创建的s的层次结构中,所以我们要这样挖出来。
d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround
RepeatDataset = type(tf.data.Dataset().repeat())
try:
while not isinstance(d, RepeatDataset):
d = d._input_dataset
except AttributeError:
warnings.warn('no epoch counter found')
epoch_counter = None
else:
epoch_counter = d._count
Run Code Online (Sandbox Code Playgroud)
请注意,使用此技术,数据集大小的计算并不准确,因为在此期间epoch_counter递增的批次通常混合来自两个连续 epoch 的样本。所以这个计算精确到你的批次长度。
len(list(dataset)) 可以在热切的模式下工作,尽管显然这不是一个好的通用解决方案。
小智 7
我看到了很多获取样本数量的方法,但实际上你可以通过以下方式轻松做到keras:
len(dataset) * BATCH_SIZE
Run Code Online (Sandbox Code Playgroud)
小智 5
您可以将其用于 TF2 中的 TFRecords:
ds = tf.data.TFRecordDataset(dataset_filenames)
ds_size = sum(1 for _ in ds)
Run Code Online (Sandbox Code Playgroud)
小智 5
这对我有用:
lengt_dataset = dataset.reduce(0, lambda x,_: x+1).numpy()
Run Code Online (Sandbox Code Playgroud)
它遍历您的数据集并增加 var x,它作为数据集的长度返回。