knu*_*nub 4 tensorflow tensorflow-datasets
我正在使用数据集API,按如下所示读取数据:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
Run Code Online (Sandbox Code Playgroud)
现在,我想使用flat_map它来过滤掉一些样本,同时在训练时动态地复制一些其他样本(这是导致我的模型的输入函数)。
该API flat_map需要返回一个Dataset对象,但是我不知道如何创建该对象。这是我要实现的伪代码实现:
def flat_map_impl(tf_example):
# Pseudo-code:
# if tf_example["a"] == 1:
# return []
# else:
# return [tf_example, tf_example]
dataset.flat_map(flat_map_impl)
Run Code Online (Sandbox Code Playgroud)
如何在flat_map函数中实现呢?
注意:我想可以通过来实现这一点py_func,但我宁愿避免这种情况。
tf.data.Dataset从a返回时创建a的最常见方法Dataset.flat_map()是使用Dataset.from_tensors()或Dataset.from_tensor_slices()。在这种情况下,因为tf_example是字典,所以使用Dataset.from_tensors()and 的组合可能是最简单的Dataset.repeat(count),条件表达式用于计算count:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
def flat_map_impl(tf_example):
count = tf.cond(tf.equal(tf_example["a"], 1)),
lambda: tf.constant(0, dtype=tf.int64),
lambda: tf.constant(2, dtype=tf.int64))
return tf.data.Dataset.from_tensors(tf_example).repeat(count)
dataset = dataset.flat_map(flat_map_impl)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2843 次 |
| 最近记录: |