使用TensorFlow数据集进行批处理,重复和随机播放有什么作用?

blu*_*lue 15 dataset tensorflow

我目前正在学习TensorFlow,但我在这段代码中遇到了困惑:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
Run Code Online (Sandbox Code Playgroud)

我知道首先数据集将保存所有数据,但是shuffle(),repeat()和batch()对数据集做了什么?请给我一个例子的解释

Vla*_*-HC 20

想象一下,您有一个数据集:[1, 2, 3, 4, 5, 6],然后:

ds.shuffle()如何工作

dataset.shuffle(buffer_size=3)将分配大小为3的缓冲区以挑选随机条目。该缓冲区将连接到源数据集。我们可以这样成像:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ?         ?
[1,2,3] <= [4,5,6]
Run Code Online (Sandbox Code Playgroud)

假设该条目2来自随机缓冲区。可用空间由源缓冲区中的下一个元素填充,即4

2 <= [1,3,4] <= [5,6]
Run Code Online (Sandbox Code Playgroud)

我们继续阅读,直到一无所有:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []
Run Code Online (Sandbox Code Playgroud)

ds.repeat()如何工作

从数据集中读取所有条目并尝试读取下一个元素后,数据集将引发错误。那就是ds.repeat()发挥作用的地方。它将重新初始化数据集,使其再次如下所示:

[1,2,3] <= [4,5,6]
Run Code Online (Sandbox Code Playgroud)

ds.batch()将产生什么

ds.batch()将采取第一batch_size项,使一批了出来。因此,示例数据集的批处理大小为3将产生两个批处理记录:

[2,1,5]
[3,6,4]
Run Code Online (Sandbox Code Playgroud)

由于我们要进行ds.repeat()批量处理,因此数据的生成将继续。但是,由于,元素的顺序将有所不同ds.random()。应该考虑的是6,由于随机缓冲区的大小,第一批中将永远不会出现这种情况。

  • @Seymour:顺序是 ds.shuffle(...).repeat().batch(..)。至少对于 TensorFlow 2.1.0 来说是这样。 (4认同)
  • 如果我不想打乱数据,因为数据是时间序列,我仍然可以使用重复和批量大小而不打乱吗? (2认同)
  • @alily,是的。那将是一个选择。另一种选择是使每个批次记录代表一个单独的时间序列记录。这样你就可以从使用 shuffle() 中受益。 (2认同)
  • 我不明白为什么第一批永远不会有 6 个?为什么不?批次不是随机抽取的吗?例如,第一批可能是 [2, 3, 6]。 (2认同)

小智 5

tf.Dataset 中的以下方法:

  1. repeat( count=0 )该方法重复数据集count的次数。
  2. shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)该方法对数据集中的样本进行打乱。的buffer_size是被随机化,并且返回作为样本的数目tf.Dataset
  3. batch(batch_size,drop_remainder=False)使用给定的批次大小创建数据集的批次,batch_size这也是批次的长度。