Tensorflow:连接多个 tf.Dataset 非常慢

Smo*_*row 5 python tensorflow tensorflow-datasets

我在 Tensorflow 1.10

现在我不确定这是否是一个错误。

我一直在尝试连接我从多个 tf.data.Dataset.from_generator 生成的大约 100 个数据集。

for i in range(1, 100):
        dataset = dataset.concatenate(
            tf.data.Dataset.from_generator(gens[i], (tf.int8, tf.int32), output_shapes=(
                (256, 256), (1))))
        print(i)
 print("before iterator")
 iterator = dataset.make_one_shot_iterator()
 print("after iterator")
Run Code Online (Sandbox Code Playgroud)

运行make_one_shot_iterator()需要很长时间。

有谁知道修复吗?

编辑:

看起来 _make_dataset.add_to_graph(ops.get_default_graph()) 似乎一次又一次地被调用,导致该函数被调用了几百万次。(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/data/ops/dataset_ops.py 函数 make_one_shot_iterator 第 162 行)

Smo*_*row 0

对于像这样的多个张量或生成器来说,运行concatenate实际上并不是最好的选择。

更好的方法是使用flat_map https://www.tensorflow.org/api_docs/python/tf/data/Dataset#flat_map。我确实更新了示例一段时间,以展示如何将其用于多个张量或文件。