如何在tf.data.Dataset中填充固定的BATCH_SIZE?

min*_*ing 4 tensorflow tensorflow-datasets

我有一个包含11个样本的数据集.当我选择BATCH_SIZEbe 2时,以下代码将出错:

dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)
Run Code Online (Sandbox Code Playgroud)

问题在于dataset = dataset.batch(batch_size),当Dataset循环进入最后一批时,剩余的样本数仅为1,那么有没有办法从以前访问过的样本中随机选取一个并生成最后一批?

Oli*_*rot 7

@mining通过填充文件名来提出解决方案.

另一种解决方案是使用tf.contrib.data.batch_and_drop_remainder.这将使用固定的批次大小批量处理数据并删除最后一个较小的批次.

在您的示例中,使用11个输入和批量大小为2,这将产生5批2个元素.

以下是文档中的示例:

dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
Run Code Online (Sandbox Code Playgroud)