tensorflow数据集先混洗然后批处理

Lim*_*huo 6 tensorflow tensorflow-datasets

我最近开始学习张量流。

我不确定是否有区别

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.shuffle(buffer_size=4)
ds.batch(4)
Run Code Online (Sandbox Code Playgroud)

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.batch(4)
ds.shuffle(buffer_size=4)
Run Code Online (Sandbox Code Playgroud)

另外,我不确定为什么我不能使用

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
Run Code Online (Sandbox Code Playgroud)

因为它给出了错误

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
AttributeError: 'TensorSliceDataset' object has no attribute 'shuffle_batch'
Run Code Online (Sandbox Code Playgroud)

谢谢!

mrr*_*rry 7

TL; DR:是的,有所不同。几乎总是,您需要在之前致电。类上没有方法,您必须分别调用这两种方法以随机播放和批处理数据集。Dataset.shuffle() Dataset.batch()shuffle_batch()tf.data.Dataset


a的转换tf.data.Dataset以与调用相同的顺序应用。Dataset.batch()将其输入的连续元素合并为输出中的单个批处理元素。通过考虑以下两个数据集,我们可以看到操作顺序的影响:

tf.enable_eager_execution()  # To simplify the example code.

# Batch before shuffle.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.batch(3)
dataset = dataset.shuffle(9)

for elem in dataset:
  print(elem)

# Prints:
# tf.Tensor([1 1 1], shape=(3,), dtype=int32)
# tf.Tensor([2 2 2], shape=(3,), dtype=int32)
# tf.Tensor([0 0 0], shape=(3,), dtype=int32)

# Shuffle before batch.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.shuffle(9)
dataset = dataset.batch(3)

for elem in dataset:
  print(elem)

# Prints:
# tf.Tensor([2 0 2], shape=(3,), dtype=int32)
# tf.Tensor([2 1 0], shape=(3,), dtype=int32)
# tf.Tensor([0 1 1], shape=(3,), dtype=int32)
Run Code Online (Sandbox Code Playgroud)

在第一个版本中(洗牌前的批次),每个批次的元素是输入中的3个连续元素;而在第二个版本(批处理前的随机播放)中,则从输入中随机采样。通常,当通过小批量随机梯度下降(某种形式)进行训练时,应从总输入中尽可能均匀地对每个批次的元素进行采样。否则,网络可能会过度适合输入数据中的任何结构,并且最终的网络将无法获得如此高的精度。

  • 感谢您清晰的解释!我很困惑,因为一些在线教程在批处理之前先进行洗牌,而另一些教程在洗牌之前进行批处理。我猜那些在洗牌之前批处理的都是错误的。 (2认同)

R. *_*Zhu 6

完全同意@mrry,但存在一种情况,您可能希望改组之前进行批处理。假设您正在处理一些将输入 RNN 的文本数据。这里每个句子都被视为一个序列,一批将包含多个序列。由于句子的长度是可变的,我们需要将一批中的句子填充到统一的长度。一种有效的方法是通过批处理将长度相似的句子组合在一起,然后进行混洗。否则,我们最终可能会收到装满<pad>令牌的批次。