tf.data:并行化加载步骤

Flo*_*ker 8 python tensorflow tensorflow-datasets tensorflow2.0

我有一个数据输入管道,它具有:

  • 不能强制转换为 a tf.Tensor(dicts 和诸如此类)的类型的输入数据点
  • 无法理解 tensorflow 类型并需要处理这些数据点的预处理函数;其中一些进行动态数据增强

我一直在尝试将其放入tf.data管道中,并且一直坚持并行运行多个数据点的预处理。到目前为止,我已经尝试过这个:

  • Dataset.from_generator(gen)在生成器中使用并进行预处理;这是有效的,但它会按顺序处理每个数据点,无论我修补它的prefetchmap调用的排列如何。并行预取是不可能的吗?
  • 将预处理封装在 a 中,tf.py_function以便我可以map在我的数据集上并行处理,但是
    1. 这需要一些非常丑陋的(反)序列化才能将奇异类型放入字符串张量中,
    2. 显然, 的执行py_function将移交给(单进程)python 解释器,所以我会坚持使用 python GIL,这对我没有多大帮助
  • 我看到你可以做一些技巧,interleave但没有发现任何没有前两个想法的问题。

我在这里错过了什么吗?我是否被迫修改我的预处理以便它可以在图形中运行,或者有没有办法对其进行多处理?

我们以前这样做的方法是使用 keras.Sequence ,它运行良好,但推动tf.dataAPI升级的人太多了。(地狱,甚至尝试 keras.Sequence with tf 2.2 yields WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.

注意:我使用的是 tf 2.2rc3

小智 3

我遇到了同样的问题并找到了一个(相对)简单的解决方案。

事实证明,正确的方法确实是首先使用tf.data.Dataset该方法创建一个对象from_generator(gen),然后再使用该方法应用自定义 python 处理函数(包装在 a 中py_functionmap。正如您所提到的,有一个技巧可以避免输入的序列化/反序列化。

诀窍是使用一个生成器,它只会生成训练集的索引。每个调用的训练索引都将传递给包装的 py_function,该函数可以反过来评估该索引处的原始数据集。然后,您可以处理数据点并将处理后的数据返回到管道的其余部分tf.data

def func(i):
    i = i.numpy() # decoding from the EagerTensor object
    x, y = processing_function(training_set[i])
    return x, y # numpy arrays of types uint8, float32

z = list(range(len(training_set))) # the index generator

dataset = tf.data.Dataset.from_generator(lambda: z, tf.uint8)

dataset = dataset.map(lambda i: tf.py_function(func=func, inp=[i], 
                                               Tout=[tf.uint8, tf.float32]), 
                      num_parallel_calls=12)

dataset = dataset.batch(1)
Run Code Online (Sandbox Code Playgroud)

请注意,在实践中,根据您训练数据集的模型,您可能需要在map以下操作之后将另一个模型应用于您的数据集batch

def _fixup_shape(x, y):
    x.set_shape([None, None, None, nb_channels])
    y.set_shape([None, nb_classes])
    return x, y
dataset = dataset.map(_fixup_shape)
Run Code Online (Sandbox Code Playgroud)

这是一个已知问题,似乎是由于该from_generator方法在某些情况下无法正确推断形状所致。因此,您需要显式传递预期的输出形状。了解更多信息: