blu*_*lue 15 dataset tensorflow
我目前正在学习TensorFlow,但我在这段代码中遇到了困惑:
dataset = dataset.shuffle(buffer_size = 10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
Run Code Online (Sandbox Code Playgroud)
我知道首先数据集将保存所有数据,但是shuffle(),repeat()和batch()对数据集做了什么?请给我一个例子的解释
Vla*_*-HC 20
想象一下,您有一个数据集:[1, 2, 3, 4, 5, 6],然后:
ds.shuffle()如何工作
dataset.shuffle(buffer_size=3)将分配大小为3的缓冲区以挑选随机条目。该缓冲区将连接到源数据集。我们可以这样成像:
Random buffer
|
| Source dataset where all other elements live
| |
? ?
[1,2,3] <= [4,5,6]
Run Code Online (Sandbox Code Playgroud)
假设该条目2来自随机缓冲区。可用空间由源缓冲区中的下一个元素填充,即4:
2 <= [1,3,4] <= [5,6]
Run Code Online (Sandbox Code Playgroud)
我们继续阅读,直到一无所有:
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
Run Code Online (Sandbox Code Playgroud)
ds.repeat()如何工作
从数据集中读取所有条目并尝试读取下一个元素后,数据集将引发错误。那就是ds.repeat()发挥作用的地方。它将重新初始化数据集,使其再次如下所示:
[1,2,3] <= [4,5,6]
Run Code Online (Sandbox Code Playgroud)
ds.batch()将产生什么
在ds.batch()将采取第一batch_size项,使一批了出来。因此,示例数据集的批处理大小为3将产生两个批处理记录:
[2,1,5]
[3,6,4]
Run Code Online (Sandbox Code Playgroud)
由于我们要进行ds.repeat()批量处理,因此数据的生成将继续。但是,由于,元素的顺序将有所不同ds.random()。应该考虑的是6,由于随机缓冲区的大小,第一批中将永远不会出现这种情况。
小智 5
tf.Dataset 中的以下方法:
repeat( count=0 )该方法重复数据集count的次数。shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)该方法对数据集中的样本进行打乱。的buffer_size是被随机化,并且返回作为样本的数目tf.Dataset。batch(batch_size,drop_remainder=False)使用给定的批次大小创建数据集的批次,batch_size这也是批次的长度。| 归档时间: |
|
| 查看次数: |
5189 次 |
| 最近记录: |