使用Dataset API生成平衡的迷你批次

lhl*_*mgr 9 tensorflow tensorflow-datasets

我对新数据集API(tensorflow 1.4rc1)有疑问.我有一个不平衡的数据集与标签01.我的目标是在预处理过程中创建平衡的迷你批次.

假设我有两个过滤的数据集:

ds_pos = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 1), []))
ds_neg = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 0), [])).repeat()
Run Code Online (Sandbox Code Playgroud)

有没有办法组合这两个数据集,使得结果数据集如下所示ds = [0, 1, 0, 1, 0, 1]:

像这样的东西:

dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.apply(...)
# dataset looks like [0, 1, 0, 1, 0, 1, ...]
dataset = dataset.batch(20)
Run Code Online (Sandbox Code Playgroud)

我目前的做法是:

def _concat(x, y):
   return tf.cond(tf.random_uniform(()) > 0.5, lambda: x, lambda: y)
dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.map(_concat)
Run Code Online (Sandbox Code Playgroud)

但我觉得有一种更优雅的方式.

提前致谢!

mrr*_*rry 6

您走在正确的轨道上。下面的示例用于Dataset.flat_map()将每对正例和负例转换为结果中的两个连续例:

dataset = tf.data.Dataset.zip((ds_pos, ds_neg))

# Each input element will be converted into a two-element `Dataset` using
# `Dataset.from_tensors()` and `Dataset.concatenate()`, then `Dataset.flat_map()`
# will flatten the resulting `Dataset`s into a single `Dataset`.
dataset = dataset.flat_map(
    lambda ex_pos, ex_neg: tf.data.Dataset.from_tensors(ex_pos).concatenate(
        tf.data.Dataset.from_tensors(ex_neg)))

dataset = dataset.batch(20)
Run Code Online (Sandbox Code Playgroud)

  • 我想对多类分类问题使用相同的方法。但是,它非常缓慢。在非二进制分类任务中,是否有更有效的方法来生产平衡的微型批次? (3认同)
  • 关于此解决方案要记住的一件事是,“ zip”将创建一个数据集,其大小等于正在压缩的最小数据集。我认为这就是为什么OP在`ds_neg`上使用`repeat`的原因。我认为这是为了确保所有多数类的数据ds_pos都将被使用。 (3认同)