Der*_*erk 5 tensorflow tensorflow-datasets
在生产环境中,我有来自N个生产者的数据必须经过网络。我在并行化tf.data.Dataset.from_generator时发现了此注释,它实际上描述了我想要的内容。
def generator(n):
# returns n-th generator function
def dataset(n):
return tf.data.Dataset.from_generator(generator(n))
ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))
# where N is the number of generators you use
Run Code Online (Sandbox Code Playgroud)
但是generator(n)函数应该是什么样子。因为当我用
def generator(n):
"""Returns the n-th generator function (for consumer n)
"""
consumer = self.consumers[n]
def gen():
for item in consumer:
yield item
return gen
Run Code Online (Sandbox Code Playgroud)
与self.consumers一个Python列表,然后我会得到错误:
TypeError:列表索引必须是整数或切片,而不是Tensor
实现几乎是正确的,但是您收到错误,因为n中的参数dataset(n)是“符号” tf.Tensor,而不是可用于在 中查找使用者的实际值self.consumers。
幸运的是,有一个解决方法,其中涉及将n可选args参数传递给tf.data.Dataset.from_generator():
def dataset(n):
return tf.data.Dataset.from_generator(generator, args=(n,))
Run Code Online (Sandbox Code Playgroud)
在幕后,在每次调用 之前from_generator()插入一些代码以转换为 Python 整数。ngenerator
| 归档时间: |
|
| 查看次数: |
267 次 |
| 最近记录: |