我想知道在使用时如何强制使用具有固定数量样本的批次Dataset。
例如,
import numpy as np
import tensorflow as tf
dataset = tf.data.Dataset.range(101).batch(10)
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess = tf.InteractiveSession()
try:
while True:
print(batch.eval().shape)
except tf.errors.OutOfRangeError:
pass
Run Code Online (Sandbox Code Playgroud)
在此玩具示例中,数据总共有 101 个样本,我要求批次为 10 个样本。迭代时,最后一批的大小为 1,这是我想避免的。
在以前的(基于队列的)API 中,tf.train.batch有一个默认allow_smaller_final_batch设置为的参数。False我想用 重现这种行为Dataset。
我想我可以使用Dataset.filter:
dataset = tf.data.Dataset.range(101).batch(10)
.filter(lambda x: tf.equal(tf.shape(x)[0], 10))
Run Code Online (Sandbox Code Playgroud)
但肯定应该有一些内置的方法来做到这一点?
对于tensorflow>=2.0.0,您可以使用asdrop_remainder方法batch的参数tf.data.Dataset:
dataset = tf.data.Dataset.batch(BATCH_SIZE, drop_remainder=True)
Run Code Online (Sandbox Code Playgroud)
drop_remainder如果最后一批在元素少于的情况下被删除,则参数设置BATCH_SIZE。默认值为 False。
我希望这对 2019+ 的读者有所帮助
| 归档时间: |
|
| 查看次数: |
3852 次 |
| 最近记录: |