如何在 TensorFlow 2.0 中使用 Dataset.window() 方法创建的窗口?

Min*_*ark 25 python tensorflow2.0

我正在尝试使用 TensorFlow 2.0 创建一个数据集,该数据集将从时间序列中返回随机窗口,以及作为目标的下一个值。

我正在使用Dataset.window(),看起来很有希望:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
for window in dataset:
    print([elem.numpy() for elem in window])
Run Code Online (Sandbox Code Playgroud)

输出:

[0, 1, 2, 3, 4]
[1, 2, 3, 4, 5]
[2, 3, 4, 5, 6]
[3, 4, 5, 6, 7]
[4, 5, 6, 7, 8]
[5, 6, 7, 8, 9]
Run Code Online (Sandbox Code Playgroud)

但是,我想使用最后一个值作为目标。如果每个窗口都是张量,我会使用:

[0, 1, 2, 3, 4]
[1, 2, 3, 4, 5]
[2, 3, 4, 5, 6]
[3, 4, 5, 6, 7]
[4, 5, 6, 7, 8]
[5, 6, 7, 8, 9]
Run Code Online (Sandbox Code Playgroud)

但是,如果我尝试这样做,则会出现异常:

dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
Run Code Online (Sandbox Code Playgroud)

Min*_*ark 30

解决方案是这样调用flat_map()

dataset = dataset.flat_map(lambda window: window.batch(5))
Run Code Online (Sandbox Code Playgroud)

现在数据集中的每个项目都是一个窗口,因此您可以像这样拆分它:

dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
Run Code Online (Sandbox Code Playgroud)

所以完整的代码是:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(5))
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))

for X, y in dataset:
    print("Input:", X.numpy(), "Target:", y.numpy())
Run Code Online (Sandbox Code Playgroud)

哪些输出:

Input: [0 1 2 3] Target: [4]
Input: [1 2 3 4] Target: [5]
Input: [2 3 4 5] Target: [6]
Input: [3 4 5 6] Target: [7]
Input: [4 5 6 7] Target: [8]
Input: [5 6 7 8] Target: [9]
Run Code Online (Sandbox Code Playgroud)

  • window() 方法返回包含窗口的数据集,其中每个窗口本身表示为数据集。类似于 {{1,2,3,4,5},{6,7,8,9,10},...},其中 {...} 表示数据集。但我们只想要一个包含张量的常规数据集:{[1,2,3,4,5],[6,7,8,9,10],...},其中[...]代表张量。在转换每个嵌套数据集后,flat_map() 方法返回嵌套数据集中的所有张量。如果我们不进行批处理,我们将得到:{1,2,3,4,5,6,7,8,9,10,...}。通过将每个窗口批量处理到其完整大小,我们可以得到我们想要的 {[1,2,3,4,5],[6,7,8,9,10],...}。清除? (16认同)
  • 尽管回答这个问题不是必需的,但您能否详细说明为什么我们需要这个 flat_map 步骤?我仍然很难理解它。 (2认同)