使用数据集固定大小批次(可能丢弃最后一批)

P-G*_*-Gn 5 python tensorflow

我想知道在使用时如何强制使用具有固定数量样本的批次Dataset

例如,

import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.range(101).batch(10)
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

sess = tf.InteractiveSession()
try:
  while True:
    print(batch.eval().shape)
except tf.errors.OutOfRangeError:
  pass
Run Code Online (Sandbox Code Playgroud)

在此玩具示例中,数据总共有 101 个样本,我要求批次为 10 个样本。迭代时,最后一批的大小为 1,这是我想避免的。

在以前的(基于队列的)API 中,tf.train.batch有一个默认allow_smaller_final_batch设置为的参数。False我想用 重现这种行为Dataset

我想我可以使用Dataset.filter

dataset = tf.data.Dataset.range(101).batch(10)
  .filter(lambda x: tf.equal(tf.shape(x)[0], 10))
Run Code Online (Sandbox Code Playgroud)

但肯定应该有一些内置的方法来做到这一点?

Aak*_*ash 7

对于tensorflow>=2.0.0,您可以使用asdrop_remainder方法batch的参数tf.data.Dataset

dataset = tf.data.Dataset.batch(BATCH_SIZE, drop_remainder=True)
Run Code Online (Sandbox Code Playgroud)

drop_remainder如果最后一批在元素少于的情况下被删除,则参数设置BATCH_SIZE。默认值为 False。

我希望这对 2019+ 的读者有所帮助

  • 这基本上是[我的评论](/sf/ask/3317631201/潜在地-discarding-last-batch-using-dataset/62761036#comment90641358_47395718)的内容(并且它自 TF 1.10 如其中所述)。 (2认同)