如何在tensorflow中使用dataset.shard?

Jia*_*nbo 7 tensorflow tensorflow-datasets

最近我在研究Tensorflow中的数据集API,并且有一种dataset.shard()用于分布式计算的方法.

这就是Tensorflow文档中所述的内容:

Creates a Dataset that includes only 1/num_shards of this dataset.

d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
Run Code Online (Sandbox Code Playgroud)

据说该方法返回原始数据集的一部分.如果我有两个工人,我应该这样做:

d_0 = d.shard(FLAGS.num_workers, worker_0)
d_1 = d.shard(FLAGS.num_workers, worker_1)
......
iterator_0 = d_0.make_initializable_iterator()
iterator_1 = d_1.make_initializable_iterator()

for worker_id in workers:
    with tf.device(worker_id):
        if worker_id == 0:
            data = iterator_0.get_next()
        else:
            data = iterator_1.get_next()
        ......
Run Code Online (Sandbox Code Playgroud)

因为文档没有指定如何进行后续调用,所以我在这里有点困惑.

谢谢!

Oli*_*rot 9

您应首先查看Distributed TensorFlow上的教程,以便更好地了解它的工作原理.

您有多个工作人员,每个工作人员运行相同的代码但差别很小:每个工作人员都会有不同的工作人员FLAGS.worker_index.

使用时tf.data.Dataset.shard,您将提供此工作人员索引,数据将在工作人员之间平均分配.

这是一个有3名工人的例子.

dataset = tf.data.Dataset.range(6)
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)


iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()

# Suppose you have 3 workers in total
with tf.Session() as sess:
    for i in range(2):
        print(sess.run(res))
Run Code Online (Sandbox Code Playgroud)

我们将有输出:

  • 0, 3 对工人0
  • 1, 4 对工人1
  • 2, 5 对工人2