使用 tensorflow 数据集打乱输入文件

Pek*_*kka 2 python dataset tensorflow

使用旧的输入管道 API,我可以:

filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
Run Code Online (Sandbox Code Playgroud)

然后将文件名传递给其他队列,例如:

reader = tf.TFRecordReader()
_, serialized_example = reader.read_up_to(filename_queue, n)
Run Code Online (Sandbox Code Playgroud)

如何使用 Dataset -API 实现类似的行为?

tf.data.TFRecordDataset()文件名的期望张量按固定顺序。

GPh*_*ilo 8

开始按顺序阅读它们,然后立即随机播放

BUFFER_SIZE = 1000 # arbitrary number
# define filenames somewhere, e.g. via glob
dataset = tf.data.TFRecordDataset(filenames).shuffle(BUFFER_SIZE)
Run Code Online (Sandbox Code Playgroud)

编辑:

这个问题的输入管道让我了解了如何使用 Dataset API 实现文件名改组:

dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.shuffle(BUFFER_SIZE) # doesn't need to be big
dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.map(decode_example, num_parallel_calls=5) # add your decoding logic here
# further processing of the dataset
Run Code Online (Sandbox Code Playgroud)

这会将一个文件的所有数据放在下一个文件之前,依此类推。文件被打乱,但其中的数据将以相同的顺序生成。您也可以替换dataset.flat_mapinterleave同时处理多个文件并从每个文件中返回样本:

dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
Run Code Online (Sandbox Code Playgroud)

注意: interleave实际上并没有在多个线程中运行,它是一个循环操作。对于真正的并行处理,请参见parallel_interleave

  • 好的,但是当您有一长串包含相同标签(用于深度学习)的 TFRecord 文件(总共超过 50000 个示例),然后是另一系列包含带有另一个标签的示例的文件时,您会怎么做。要使改组工作,您需要一个大于 50000 的缓冲区,因此需要大量 RAM。这不是解决方案。改组文件名是一个更简单的解决方案。 (2认同)
  • @Pekka 我认为编辑可能是您的目标 (2认同)