在Tensorflow的数据集API中使用flat_map

knu*_*nub 4 tensorflow tensorflow-datasets

我正在使用数据集API,按如下所示读取数据:

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
Run Code Online (Sandbox Code Playgroud)

现在,我想使用flat_map它来过滤掉一些样本,同时在训练时动态地复制一些其他样本(这是导致我的模型的输入函数)。

该API flat_map需要返回一个Dataset对象,但是我不知道如何创建该对象。这是我要实现的伪代码实现:

def flat_map_impl(tf_example):
    # Pseudo-code:
    # if tf_example["a"] == 1:
    #     return []
    # else:
    #     return [tf_example, tf_example]

dataset.flat_map(flat_map_impl)
Run Code Online (Sandbox Code Playgroud)

如何在flat_map函数中实现呢?

注意:我想可以通过来实现这一点py_func,但我宁愿避免这种情况。

mrr*_*rry 7

tf.data.Dataset从a返回时创建a的最常见方法Dataset.flat_map()是使用Dataset.from_tensors()Dataset.from_tensor_slices()。在这种情况下,因为tf_example是字典,所以使用Dataset.from_tensors()and 的组合可能是最简单的Dataset.repeat(count)条件表达式用于计算count

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

def flat_map_impl(tf_example):
  count = tf.cond(tf.equal(tf_example["a"], 1)),
                  lambda: tf.constant(0, dtype=tf.int64),
                  lambda: tf.constant(2, dtype=tf.int64))

  return tf.data.Dataset.from_tensors(tf_example).repeat(count)

dataset = dataset.flat_map(flat_map_impl)
Run Code Online (Sandbox Code Playgroud)