使用 tf.contrib.data.parallel_interleave 并行化 tf.from_generator

Nit*_*tin 7 python keras tensorflow tensorflow-datasets

我有一堆 JSON 数组文件(准确地说是 AVRO),每个文件都会产生多个样本来训练 Keras 模型。使用来自@GPhilo@jsimsa 的想法,我能够想出这个来并行化我的输入管道。无法弄清楚如何设计generator(n)来划分处理文件的工作。代码在内部失败,parse_file(f)因为该函数需要一个字符串文件路径而不是一个Tensor,

N = num_cores = 2
files_to_process = ["f1.avro", "f2.avro", "f3.avro"]
shuffle_size = prefetch_buffer = 1000
batch_size = 512

def generator(n):
    size = math.ceil(len(files_to_process) / N)
    start_index = n * size
    end_index = start_index + size

    def gen():
        # for f in files_to_process[start_index:end_index]:
        for f in tf.slice(files_to_process, start_index, size):
            yield f

    return gen

def dataset(n):
    return tf.data.Dataset.from_generator(generator(n), (tf.string,))

def process_file(f):
    examples_x, examples_y = parse_file(f)
    return examples_x, examples_y

ds = tf.data.Dataset.range(N)
ds = ds.apply(tf.contrib.data.parallel_interleave(dataset, cycle_length=N))
ds = ds.map(process_file, num_parallel_calls=N)
ds = ds.prefetch(prefetch_buffer)
ds = ds.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
ds = ds.batch(batch_size).shuffle(shuffle_size)

...
myTfKerasModel.fit( ds.make_one_iterator(), NUM_TRAIN_SAMPLES // batch_size )
Run Code Online (Sandbox Code Playgroud)
  • generator(n)在这里设计的正确方法是什么
  • 这是使用parallel_interleave和设计我的输入管道的优化方法吗flat_map

GPh*_*ilo 7

在我看来,您正在不必要地使用发电机使您的生活复杂化。这就是我实现输入管道的方式:

def parse_file_tf(filename):
    return tf.py_func(parse_file, [filename], [tf.float32, tf.float32])

# version with map
files = tf.data.Dataset.from_tensor_slices(files_to_process)
dataset = files.map(parse_file_tf, num_parallel_calls=N)
dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
dataset = dataset.batch(batch_size).shuffle(shuffle_size).prefetch(2)
it = dataset.make_one_shot_iterator()
Run Code Online (Sandbox Code Playgroud)

为了测试它,我将一个虚拟对象定义parse_file为:

i=0
def parse_file(f):
    global i
    i += 1
    return np.asarray([i]*i, dtype=np.float32), np.asarray([i]*i, dtype=np.float32) # mimicks variable-length examples_x, examples_y
Run Code Online (Sandbox Code Playgroud)

我将其输入到显示迭代器返回内容的基本循环中:

sess = tf.Session()
try:
    while True:
        x, y = it.get_next()
        vx, vy = sess.run([x,y])
        print(vx)
        print(vy)
except tf.errors.OutOfRangeError:
    pass
sess.close()
Run Code Online (Sandbox Code Playgroud)

运行上面的代码打印:

[2. 3. 2. 1. 3. 3.]
[2. 3. 2. 1. 3. 3.]
Run Code Online (Sandbox Code Playgroud)

管道说明

本质上,我将并行化问题留给map,在那里我可以传递它应该运行的线程数。不需要生成器迭代范围和那些额外的复杂性。

我选择 map overparallel_interleave因为后者要求您为Dataset它返回的每个项目生成一个实例,在您的情况下这没有意义,因为您在运行时已经在内存中加载了所有值parse_fileparallel_interleave如果您缓慢地生成值(例如,通过应用tf.data.TFRecordDataset到文件名列表)是有意义的,但是如果您的数据集适合内存,请选择map.

关于tf.py_func限制,它们不会影响您经过训练的网络,只会影响输入管道。理想情况下,您的训练和网络的最终使用会有不同的管道。您只需要注意后者的限制,而对于训练(除非您通过分布式训练和/或在机器之间移动训练进行非常具体的操作),您是相当安全的。


带生成器的版本

如果您的 JSON 文件非常大并且它们的内容无法放入内存中,您可以使用生成器,但与您开始使用的方法略有不同。这个想法是,生成器一次遍历 JSON 文件和yield一条记录。然后,生成器必须是您的parse_file功能。例如,假设您有以下parse_file生成器:

i = 3
def parse_file(filename):
    global i
    i += 1
    ctr = 0
    while ctr < i:
        yield ctr, ctr
Run Code Online (Sandbox Code Playgroud)

在这种情况下,管道将如下所示:

def wrap_generator(filename):
    return tf.data.Dataset.from_generator(parse_file(filename), [tf.int32, tf.int32])

files = tf.data.Dataset.from_tensor_slices(files_to_process)
dataset = files.apply(tf.contrib.data.parallel_interleave(wrap_generator, cycle_length=N))
dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
dataset = dataset.shuffle(shuffle_size).batch(batch_size).prefetch(2)
it = dataset.make_one_shot_iterator()
Run Code Online (Sandbox Code Playgroud)

请注意,这里我们需要使用,parallel_interleave因为我们将生成器转换Dataset为从中提取值的实例。其余保持不变。

将其送入与上述打印相同的示例循环:

[6. 5. 4. 4. 6. 5. 6. 6. 5. 4. 6. 4. 5. 5. 6.]
[6. 5. 4. 4. 6. 5. 6. 6. 5. 4. 6. 4. 5. 5. 6.]
Run Code Online (Sandbox Code Playgroud)

  • 对!是的,确实我为 `files.apply(...)` 留下了错误的一行,谢谢你抓住它,我现在在代码中修复它;) 我不知道有任何用于输入管道性能的可视化工具,我知道这是 TF 团队已经工作了一段时间的领域,但这就是我所知道的。此处提供有关如何改进输入管道的提示:https://www.tensorflow.org/performance/datasets_performance (2认同)