dataset.map,Dataset.prefetch和Dataset.shuffle中buffer_size的含义

Ujj*_*wal 68 tensorflow tensorflow-gpu tensorflow-datasets

根据TensorFlow 文档,类prefetchmap方法tf.contrib.data.Dataset,都有一个名为的参数buffer_size.

对于prefetch方法,该参数称为buffer_size并且根据文档:

buffer_size:tf.int64标量tf.Tensor,表示预取时将被缓冲的最大元素数.

对于该map方法,该参数称为output_buffer_size并且根据文档:

output_buffer_size :(可选.)tf.int64标量tf.Tensor,表示将被缓冲的最大处理元素数.

类似地,对于该shuffle方法,出现相同的数量并且根据文档:

buffer_size:tf.int64标量tf.Tensor,表示新数据集将从中采样的数据集中的元素数.

这些参数之间有什么关系?

假设我创建一个Dataset对象如下:

 tr_data = TFRecordDataset(trainfilenames)
    tr_data = tr_data.map(providefortraining, output_buffer_size=10 * trainbatchsize, num_parallel_calls\
=5)
    tr_data = tr_data.shuffle(buffer_size= 100 * trainbatchsize)
    tr_data = tr_data.prefetch(buffer_size = 10 * trainbatchsize)
    tr_data = tr_data.batch(trainbatchsize)
Run Code Online (Sandbox Code Playgroud)

buffer上述代码段中的参数扮演了什么角色?

mrr*_*rry 109

TL; DR尽管名称相似,但这些论点的含义却截然不同.的buffer_sizeDataset.shuffle()可以影响你的数据集的随机性,因此在其中的元件所产生的顺序.该buffer_sizeDataset.prefetch()只影响它需要产生下一个元素的时间.


buffer_size在参数tf.data.Dataset.prefetch()output_buffer_size在参数tf.contrib.data.Dataset.map()调整提供的方式表现你的输入管道:两个参数告诉TensorFlow至多创造一个缓冲buffer_size的元素,和一个后台线程来填补背景缓冲区.(注意,我们output_buffer_sizeDataset.map()转移tf.contrib.data到的时候删除了参数tf.data.新代码应该使用Dataset.prefetch()后来map()获得相同的行为.)

添加预取缓冲区可以通过将数据预处理与下游计算重叠来提高性能.通常,在管道的最末端添加一个小的预取缓冲区(可能只有一个元素)是最有用的,但是更复杂的管道可以从额外的预取中受益,特别是当生成单个元素的时间可能变化时.

相反,影响变换随机性buffer_size论据.我们设计了转换(就像它替换的函数一样)来处理太大而无法放入内存的数据集.它不是对整个数据集进行混洗,而是维护元素的缓冲区,并从该缓冲区中随机选择下一个元素(将其替换为下一个输入元素,如果有的话).更改值会影响混洗的统一程度:如果大于数据集中元素的数量,则会得到统一的随机数; 如果它是那么你根本没有洗牌.对于非常大的数据集,典型的"足够好"的方法是在训练之前将数据随机地分成多个文件,然后均匀地混洗文件名,然后使用较小的shuffle缓冲区.但是,适当的选择取决于培训工作的确切性质.tf.data.Dataset.shuffle()Dataset.shuffle()tf.train.shuffle_batch()buffer_sizebuffer_sizebuffer_size1



Oli*_*rot 101

的重要性buffer_sizeshuffle()

我想跟进从@mrry以前的答案强调重要性buffer_sizetf.data.Dataset.shuffle().

在某些情况下,拥有一个低buffer_size不仅会给你带来低劣的洗牌:它可能会破坏你的整个训练.


一个实际的例子:猫分类器

例如,假设您正在训练图像上的猫分类器,并且您的数据按以下方式组织(10000每个类别中包含图像):

train/
    cat/
        filename_00001.jpg
        filename_00002.jpg
        ...
    not_cat/
        filename_10001.jpg
        filename_10002.jpg
        ...
Run Code Online (Sandbox Code Playgroud)

输入数据的标准方法tf.data可以是拥有文件名列表和相应标签列表,并用于tf.data.Dataset.from_tensor_slices()创建数据集:

filenames = ["filename_00001.jpg", "filename_00002.jpg", ..., 
             "filename_10001.jpg", "filename_10002.jpg", ...]
labels = [1, 1, ..., 0, 0...]  # 1 for cat, 0 for not_cat

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.shuffle(buffer_size=1000)  # 1000 should be enough right?
dataset = dataset.map(...)  # transform to images, preprocess, repeat, batch...
Run Code Online (Sandbox Code Playgroud)

上面代码的一个大问题是数据集实际上不会以正确的方式进行洗牌.对于大约一个时代的前半部分,我们只会看到猫图像,而下半部分只会看到非猫图像.这会伤害训练很多.
在训练开始时,数据集将获取第一个1000文件名并将它们放入缓冲区,然后在它们中随机选择一个.由于所有第一1000张图像都是猫的图像,我们只会在开头选择猫图像.

这里的解决方法是确保buffer_size大于20000或提前洗牌filenameslabels(具有相同指数明显).

由于在内存中存储所有文件名和标签不是问题,我们实际上buffer_size = len(filenames)可以确保将所有内容混合在一起.确保tf.data.Dataset.shuffle()在应用重变换之前调用(例如读取图像,处理它们,批处理......).

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.shuffle(buffer_size=len(filenames)) 
dataset = dataset.map(...)  # transform to images, preprocess, repeat, batch...
Run Code Online (Sandbox Code Playgroud)

要点是总是仔细检查洗牌会做什么.捕获这些错误的一个好方法可能是绘制批次随时间的分布(确保批次包含与训练集相同的分布,在我们的示例中包含半猫和一半非猫).

  • 谢谢.这是一个非常明确的答案:) (3认同)
  • 不是错别字:) 数据集每个类都有 10k 个图像,所以总缓冲区大小应该在 20k 以上。但是在上面的示例中,我采用了 1k 的缓冲区大小,这太低了。 (3认同)
  • 这种低缓冲区大小的问题在于,您的第一批中只会有猫。因此,该模型将轻松学会仅预测“猫”。训练网络的最佳方法是让批次具有相同数量的“猫”和“非猫”。 (2认同)

小智 6

import tensorflow as tf
def shuffle():
    ds = list(range(0,1000))
    dataset = tf.data.Dataset.from_tensor_slices(ds)
    dataset=dataset.shuffle(buffer_size=500)
    dataset = dataset.batch(batch_size=1)
    iterator = dataset.make_initializable_iterator()
    next_element=iterator.get_next()
    init_op = iterator.initializer
    with tf.Session() as sess:
        sess.run(init_op)
        for i in range(100):
            print(sess.run(next_element), end='')

shuffle()
Run Code Online (Sandbox Code Playgroud)

输出量

[298] [326] [2] [351] [92] [398] [72] [134] [404] [378] [238] [131] [369] [324] [35] [182] [441 ] [370] [372] [144] [77] [11] [199] [65] [346] [418] [493] [343] [444] [470] [222] [83] [61] [ 81] [366] [49] [295] [399] [177] [507] [288] [524] [401] [386] [89] [371] [181] [489] [172] [159] [195] [232] [160] [352] [495] [241] [435] [127] [268] [429] [382] [479] [519] [116] [395] [165] [233] ] [37] [486] [553] [111] [525] [170] [571] [215] [530] [47] [291] [558] [21] [245] [514] [103] [ 45] [545] [219] [468] [338] [392] [54] [139] [339] [448] [471] [589] [321] [223] [311] [234] [314]

  • 这表明对于迭代器产生的每个元素,缓冲区中都填充有数据集中之前不在缓冲区中的下一个元素。 (2认同)