交错tf.data.Datasets

Rob*_*obR 9 tensorflow

我正在尝试使用tf.data.Dataset交错两个数据集,但遇到了问题.给出这个简单的例子:

ds0 = tf.data.Dataset()
ds0 = ds0.range(0, 10, 2)
ds1 = tf.data.Dataset()
ds1 = ds1.range(1, 10, 2)
dataset = ...
iter = dataset.make_one_shot_iterator()
val = iter.get_next()
Run Code Online (Sandbox Code Playgroud)

什么是...产生类似的输出0, 1, 2, 3...9

似乎dataset.interleave()似乎是相关的,但我无法以不产生错误的方式表达语句.

mrr*_*rry 21

MattScarpino在评论中走在正确的轨道上.您可以使用Dataset.zip()Dataset.flat_map()来展平多元素数据集:

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)

# Zip combines an element from each input into a single element, and flat_map
# enables you to map the combined element into two elements, then flattens the
# result.
dataset = tf.data.Dataset.zip((ds0, ds1)).flat_map(
    lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(
        tf.data.Dataset.from_tensors(x1)))

iter = dataset.make_one_shot_iterator()
val = iter.get_next()
Run Code Online (Sandbox Code Playgroud)

话虽如此,你对使用的直觉Dataset.interleave()是非常明智的.我们正在调查您可以更轻松地完成此任务的方法.


PS.作为替代方案,如果您更改方式和定义,您可以使用它Dataset.interleave()来解决问题:ds0ds1

dataset = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.range(x, 10, 2), cycle_length=2, block_length=1)
Run Code Online (Sandbox Code Playgroud)

  • 使用“interleave”的答案缺少了一些东西,不是吗?它不使用“ds0”和“ds1”。 (4认同)
  • 在 2020 年读到这篇文章,我想“交错”确实是**的**方式,对吗?您能解释一下这些差异吗?例如哪个会更有效率? (2认同)

Pav*_*l K 5

如果您不需要保留要交错的项目的严格顺序,tf.data.experimental.sample_from_datasets方法也很有用。

就我而言,我不得不将现实生活中的数据与一些合成数据交织在一起,因此顺序对我来说不是问题。然后这可以很容易地完成如下

dataset = tf.data.experimental.sample_from_datasets([ds0, ds1])
Run Code Online (Sandbox Code Playgroud)

请注意,结果将是不确定的,某些项目可以从同一数据集提取两次,但通常它与常规交错非常相似。

这种方法的优点:

  • 您可以在一个方法调用中使用多个数据集
  • 您可以使用weights参数为每个数据集指定样本的一部分(例如,我只想生成一小部分数据,所以我使用了weights=[0.9, 0.1]