Tensorflow数据集API:使用parallel_interleave并行化tf.data.Dataset.from_generator

Der*_*erk 5 tensorflow tensorflow-datasets

在生产环境中,我有来自N个生产者的数据必须经过网络。我在并行化tf.data.Dataset.from_generator时发现了此注释,它实际上描述了我想要的内容。

def generator(n):
  # returns n-th generator function

def dataset(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))

# where N is the number of generators you use
Run Code Online (Sandbox Code Playgroud)

但是generator(n)函数应该是什么样子。因为当我用

 def generator(n):
        """Returns the n-th generator function (for consumer n)
        """
        consumer = self.consumers[n]

        def gen():
            for item in consumer:
                yield item

        return gen
Run Code Online (Sandbox Code Playgroud)

与self.consumers一个Python列表,然后我会得到错误:

TypeError:列表索引必须是整数或切片,而不是Tensor

mrr*_*rry 0

实现几乎是正确的,但是您收到错误,因为n中的参数dataset(n)是“符号” tf.Tensor,而不是可用于在 中查找使用者的实际值self.consumers

幸运的是,有一个解决方法,其中涉及将n可选args参数传递给tf.data.Dataset.from_generator()

def dataset(n):
  return tf.data.Dataset.from_generator(generator, args=(n,))
Run Code Online (Sandbox Code Playgroud)

在幕后,在每次调用 之前from_generator()插入一些代码以转换为 Python 整数。ngenerator