并行化tf.data.Dataset.from_generator

mat*_*cey 25 tensorflow tensorflow-datasets

我有一个非常简单的输入管道,from_generator非常适合......

dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                        (tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
Run Code Online (Sandbox Code Playgroud)

其中complex_img_label_generator动态生成图像,并返回表示一个numpy的阵列(H, W, 3)图像和一个简单的string标签.处理不是我可以表示从文件和tf.image操作中读取的内容.

我的问题是关于如何平衡发电机?我如何让N个这些生成器在自己的线程中运行.

一个想法是使用dataset.mapnum_parallel_calls处理线程; 但是地图在张量上运行......另一个想法是创建多个生成器,每个生成器都有自己的,prefetch并以某种方式加入它们,但我看不出我如何加入N个生成器流?

我可以遵循任何规范的例子吗?

mat*_*cey 21

事实证明,Dataset.map如果我让生成器超级轻量级​​(仅生成元数据),然后将实际的重度照明移动到无状态函数,我可以使用.通过这种方式,我可以.map使用a 来平行重型提升部件py_func.

作品; 但感觉有点笨拙...很高兴能够添加num_parallel_callsfrom_generator:)

def pure_numpy_and_pil_complex_calculation(metadata, label):
  # some complex pil and numpy work nothing to do with tf
  ...

dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label

def wrapped_complex_calulation(metadata, label):
  return tf.py_func(func = pure_numpy_and_pil_complex_calculation,
                    inp = (metadata, label),
                    Tout = (tf.uint8,    # (H,W,3) img
                            tf.string))  # label
dataset = dataset.map(wrapped_complex_calulation,
                      num_parallel_calls=8)

dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
Run Code Online (Sandbox Code Playgroud)

  • 仅供参考,与`tf.py_func()`的并行可能无法自行加速,请参阅[此答案](/sf/answers/3414672551/). (4认同)
  • 自答案以来,TensorFlow已将`num_parallel_calls`添加到`from_generator`吗? (3认同)
  • @mikkola如果不加快速度,还有其他建议吗?谢谢 (2认同)

Chr*_*ker 7

我正在from_indexabletf.data.Dataset https://github.com/tensorflow/tensorflow/issues/14448设计

这样做的好处from_indexable是可以并行化,而python生成器则不能并行化。

函数from_indexable使一个tf.data.range,将可索引的索引包装在一个泛型中tf.py_func并调用map。

对于那些现在想要a的人from_indexable,这里是lib代码

import tensorflow as tf
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args):
            nonlocal output_shapes

            flat_output_types = nest.flatten(output_types)
            flat_values = tf.py_func(
                func, 
                inp=args, 
                Tout=flat_output_types,
                stateful=stateful, name=name
            )
            if output_shapes is not None:
                # I am not sure if this is nessesary
                output_shapes = nest.map_structure_up_to(
                    output_types, tensor_shape.as_shape, output_shapes)
                flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                for ret_t, shape in zip(flat_values, flattened_shapes):
                    ret_t.set_shape(shape)
            return nest.pack_sequence_as(output_types, flat_values)
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
Run Code Online (Sandbox Code Playgroud)

这是一个示例(注意:from_indexable具有num_parallel_calls argument

class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item+1, 10))

ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
    print(sess.run(entry).shape)
    print(sess.run(entry).shape)
Run Code Online (Sandbox Code Playgroud)

更新 2018年6月10日:由于https://github.com/tensorflow/tensorflow/pull/15121被合并,因此代码from_indexable简化为:

import tensorflow as tf

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            return tf.contrib.framework.py_func(
                func=func, 
                args=args, kwargs=kwargs, 
                output_types=output_types, output_shapes=output_shapes, 
                stateful=stateful, name=name
            )
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
Run Code Online (Sandbox Code Playgroud)

  • 不幸的是,它经不起时间的考验,因为 tf2 没有 contrib,并且 py_func 已被 py_function 取代,而 py_function 没有 output_shapes、args、kwargs、stateful。最后, py_function 的输出返回未知形状,可以在图中使用。 (2认同)

jsi*_*msa 5

将在 中完成的工作限制在generator最低限度并使用 a 并行化昂贵的处理map是明智的。

或者,您可以使用parallel_interleave以下方法“加入”多个生成器:

定义生成器(n):
  # 返回第 n 个生成器函数

定义数据集(n):
  返回 tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset,cycle_lenght=N))

# 其中 N 是您使用的生成器数量

  • 我真的很喜欢这个。但是 generator(n) 应该返回第 n 个生成器,n 在这里是一个张量。如何获得第n个生成器? (2认同)