从 TensorFlow 中的多个文件中过滤批次

Ric*_*ard 6 python tensorflow tensorflow-datasets tensorflow2.0

我正在从多个文件中读取 TF 示例,并希望在将它们传递给某些批处理操作之前过滤这些示例。

不幸的是,我为此考虑的两种方法似乎都失败了。

接下来是 MWE。

创建数据集并定义几个有用的函数:

import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow.compat.v1 as tfv1
tfv1.enable_v2_behavior()

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def serialize_example(feature0, feature1):
  """
  Creates a tf.Example message ready to be written to a file.
  """
  # Create a dictionary mapping the feature name to the tf.Example-compatible
  # data type.
  feature = {
      'filter_on_this': _float_feature(feature0),
      'sequence':       _float_feature(feature1),
  }

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

for i in range(10):
  # Write the `tf.Example` observations to the file.
  with tf.io.TFRecordWriter("out{0}.tfexamples".format(i)) as writer:
    for i in range(1000):
      example = serialize_example(np.random.random(1), np.random.random(np.random.randint(50,70)))
      writer.write(example)

sequence_numeric_features_spec = {
    'filter_on_this': tf.io.FixedLenFeature((1,), dtype=tf.float32, default_value=0), #tf.io.VarLenFeature(dtype=tf.float32)
    'sequence':       tf.io.VarLenFeature(dtype=tf.float32),
}

DEFAULT_DTYPE_VALUES = {
    tf.float32: 0.,
    tf.string: '',
    tf.int64: 0,
}

def parse_record2(record):
  record = tf.io.parse_example(record, features=sequence_numeric_features_spec)
  for key, feature in record.items():
    if not isinstance(feature, tf.SparseTensor):
      continue
    dense = tf.sparse.to_dense(sp_input=feature, default_value=DEFAULT_DTYPE_VALUES[feature.values.dtype])
    dense = tf.expand_dims(dense, axis=2)
    record[key] = dense
  return record
Run Code Online (Sandbox Code Playgroud)

第一种方法:

files   = tf.data.Dataset.list_files(file_pattern="out*.tfexamples", shuffle=True, seed=123456789)
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100), cycle_length=8, num_parallel_calls=4)
dataset = dataset.batch(10)
dataset = dataset.map(parse_record2, num_parallel_calls=20)
dataset = dataset.filter(lambda x: x['filter_on_this']<0.5)
for x in dataset.take(1):
  pass
Run Code Online (Sandbox Code Playgroud)

这种方法是可取的,因为当独立解析示例时,解码转换parse_record2需要很长时间。批处理示例大大加快了转换速度。不幸的是,这引发了错误:

ValueError:predicate返回类型必须可转换为标量布尔张量。是 TensorSpec(shape=(None, 1), dtype=tf.bool, name=None)。

第二种方法

由于批次似乎有问题,我尝试了一种没有它们的替代方案:

files   = tf.data.Dataset.list_files(file_pattern="out*.tfexamples", shuffle=True, seed=123456789)
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100), cycle_length=8, num_parallel_calls=4)
dataset = dataset.batch(20)
dataset = dataset.map(parse_record2, num_parallel_calls=20)
dataset = dataset.unbatch()
dataset = dataset.filter(lambda x: x['filter_on_this'][0]<0.5)
dataset = dataset.batch(20)
for x in dataset.take(4):
  pass
Run Code Online (Sandbox Code Playgroud)

但是,除了看起来性能较低的方法之外,这还失败了

InvalidArgumentError:无法批量处理组件 1 中具有不同形状的张量。第一个元素的形状为 [69,1],元素 2 的形状为 [66,1]。

以块为单位解码示例并将它们过滤成相同大小的批次的好方法是什么?