tf.Data:什么是并行交错的落后者?

mik*_*ola 4 python tensorflow tensorflow-datasets

interleave是一种tf.Data.Dataset可用于将来自多个数据集的元素交错在一起的方法.tf.contrib.data.parallel_interleave在...的帮助下提供相同功能的并行版本apply.

我可以看到,并行版本允许并行读取许多数据集并为它们提供缓冲区将提高吞吐量.但文档还说明了如何parallel_interleave提高数据吞吐量:

与tf.data.Dataset.interleave不同,它从并行的cycle_length嵌套数据集中获取元素,这增加了吞吐量,尤其是在存在落后者的情况下.

什么是落后者,为什么parallel_interleave在他们的存在方面的吞吐量特别好?

Cip*_*agă 7

落后者是一种比平时更长的功能来产生其输出.这可能是由于网络拥塞或随机性的奇怪组合.

interleave在单个线程上以顺序方式完成所有处理.在下面的模式中,让我们___表示等待IO/Computation,<waiting>表示等待轮到一个元素111表示生成第一个元素(1).

假设我们有一个目录数据集,ds = [A, B, C, D]我们1,2,3...从每个目录生成文件.然后使用r = ds.interleave(cycle_length=3, block_length=2)将工作类似这样:

A: ___111___222
B:   <waiting> ___111___________222
C:   <waiting> <waiting> <waiting> ___111___222

R: ____A1____A2____B1____________B2____C1____C2
Run Code Online (Sandbox Code Playgroud)

您会看到,如果从B散布中生成元素,则所有后续元素都必须等待处理.

parallel_interleave有两种方式帮助落后者.首先,它并行地启动循环中的每个元素(因此名称).因此,生产架构变为:

A: ___111___222
B: ___<waiting>111___________222
C: ___<waiting><waiting><waitin>111___222

R: ____A1____A2_B1____________B2_C1____C2|....|
Run Code Online (Sandbox Code Playgroud)

这样做有助于通过并行等待减少无用的等待.该部分|....|显示了与顺序版本相比节省了多少.

它帮助的第二种方式是允许sloppy参数.如果我们将其设置为True,则允许跳过不可用的元素,直到它可用为止,代价是产生非确定性的顺序.这是如何做:

A: ___111___<w>222
B: ___<w>111___________222
C: ___<w><w>111___222

R: ____A1_B1_C1_A2_C2___B2|...................|
Run Code Online (Sandbox Code Playgroud)

看看节省!! 而且还要看元素的顺序!


我在代码中重现这些.这是一种丑陋的方式,但它有点说明了差异.

from time import sleep
DS = tf.data.Dataset

def repeater(val):
    def _slow_gen():
        for i in range(5):
            if i % 2:
                sleep(1)
            yield i
    return DS.from_generator(_slow_gen, tf.int8)

ds = DS.range(5)

slow_ds = ds.interleave(repeater, cycle_length=2, block_length=3)

para_ds = ds.apply(tf.contrib.data.parallel_interleave(
    repeater, cycle_length=2, block_length=3)
)

sloppy_ds = ds.apply(tf.contrib.data.parallel_interleave(
    repeater, cycle_length=2, block_length=3, sloppy=True)
)


%time apply_python_func(slow_ds, print, sess)
# 10 sec, you see it waiting each time

%time apply_python_func(para_ds, print, sess) 
#  3 sec always! you see it burping a lot after the first wait

%time apply_python_func(sloppy_ds, print, sess) 
# sometimes 3, sometimes 4 seconds
Run Code Online (Sandbox Code Playgroud)

这是显示数据集的功能

def apply_python_func(ds, func, sess):
    """Exact values from ds using sess and apply func on them"""
    it = ds.make_one_shot_iterator()
    next_value = it.get_next()
    num_examples = 0
    while True:
        try:
            value = sess.run(next_value)
            num_examples += 1
            func(value)
        except tf.errors.OutOfRangeError:
            break
    print('Evaluated {} examples'.format(num_examples)) 
Run Code Online (Sandbox Code Playgroud)