tf.data.Dataset.padded_batch填充每个功能的方式有所不同

blu*_*ers 7 python tensorflow tensorflow-datasets

我有一个tf.data.Dataset拥有3种不同功能的实例

  • label 这是一个标量
  • sequence_feature 这是一个标量序列
  • seq_of_seqs_feature 这是一个序列序列特征

我试图用来tf.data.Dataset.padded_batch()生成填充数据作为模型输入-我想以不同的方式填充每个功能。

批处理示例:

[{'label': 24,
  'sequence_feature': [1, 2],
  'seq_of_seqs_feature': [[11.1, 22.2],
                          [33.3, 44.4]]},
 {'label': 32,
  'sequence_feature': [3, 4, 5],
  'seq_of_seqs_feature': [[55.55, 66.66]]}]
Run Code Online (Sandbox Code Playgroud)

预期产量:

[{'label': 24,
  'sequence_feature': [1, 2, 0],
  'seq_of_seqs_feature': [[11.1, 22.2],
                          [33.3, 44.4]]},
 {'label': 32,
  'sequence_feature': [3, 4, 5],
  'seq_of_seqs_feature': [[55.55, 66.66],
                           0.0, 0.0    ]}]
Run Code Online (Sandbox Code Playgroud)

如您所见,label不应填充功能,sequence_featureseq_of_seqs_feature应该用给定批次中相应的最长条目填充和和。

mrr*_*rry 11

tf.data.Dataset.padded_batch()方法允许您为padded_shapes生成的批次的每个组件(功能)指定。例如,如果您的输入数据集称为ds

padded_ds = ds.padded_batch(
    BATCH_SIZE,
    padded_shapes={
        'label': [],                          # Scalar elements, no padding.
        'sequence_feature': [None],           # Vector elements, padded to longest.
        'seq_of_seqs_feature': [None, None],  # Matrix elements, padded to longest
    })                                        # in each dimension.
Run Code Online (Sandbox Code Playgroud)

请注意,该padded_shapes参数与输入数据集的元素具有相同的结构,因此在这种情况下,它将使用一个字典,该字典的键与您的功能名称相匹配。