Alb*_*ert 8 python tensorflow tensorflow-datasets horovod
使用 Horovod,您基本上运行 N 个独立实例(因此它是图间复制的一种形式),并且它们通过特殊的 Horovod 操作(基本上是广播 + 减少)进行通信。
现在让我们说实例 0 或其他一些外部实例加载您的数据(通过tf.data.Dataset)。你将如何分发iterator.get_next()到每个实例?使用 Horovod 广播效率低下,因为您会将所有数据复制到所有实例。
在每个实例中都有数据集,并在那里完成所有加载,然后shard在数据集上使用也将是低效的,因为您会在任何地方加载数据,然后丢弃 (N-1)/N 个数据。所以这就是为什么也不想要分片,而是只在单个(生产者/数据集工作人员)实例中加载数据集,然后将批次分发给所有火车工作人员。
我猜 TFMultiDeviceIterator提供了一些类似的功能(或基本上完全相同),但我不确定它是否与 Horovod 一起使用,以及您将如何设置它?
或者,也许您可以通过 TF 工作人员以某种方式进行分发(指南?(也许这也是您的配置MultiDeviceIterator方式?)
如果可能的话,这应该通过 TensorFlow 操作/函数(有许多相关的函数可能已经给了我这个,但我可能不知道它们,或者误解了它是如何工作的)。或者也许答案是 TensorFlow 还没有提供任何这样的功能?(知道这仍然很有用。然后我会用 C++ 实现我自己的解决方案,包装为 TensorFlow。但在此之前,最好知道这是否真的有必要。)
(相关的还有这个 Horovod 问题。)
(这个问题实际上比 Horovod 更通用一些,尽管 Horovod 可能是一个很好的例子。对于分布式 TensorFlow,您可能总是遇到这个问题?)
(我收集了所有的概述分布式TensorFlow术语和方面在这里,多为澄清。)
正如您所说,复制每个实例中的数据并为每个实例分片数据是不切实际的。
一种解决方案是将数据流程中的数据分开,并让每个实例从数据流程中提取数据,如下图所示。例如,可以使用队列来建立这种通信。
在这样的系统中,数据处理将加载数据集,将其预处理为批次并将批次推送到队列中。然后,每个训练实例将从该队列中提取批次。例如,您可以将队列作为生成器传递给数据集 API(请参阅tf.data.Dataset.from_generator)。此外,如果批次的生产速度不够快,则可以创建更多数据处理以增加批次吞吐量。
根据您的用例,实现细节会有所不同。有关更多信息,您可以查找网络和进程间通信以及多处理管道和队列。
Training
+--------------+ ++
| | |
+----+ Instance 1 | |
| | | |
| +--------------+ |
| |
Preprocessing | |
+--------------------+ +----> X |
| | | |
Load | | Batches + X |
Dataset+------> Data Process +--------->Queue | N instances
| | + X | Distributed training
| | | | For example, using
+--------------------+ +----> X | Horovod broadcast + reduce
| |
| Training |
| +--------------+ |
| | | |
+----+ Instance N | |
| | |
+--------------+ ++
Run Code Online (Sandbox Code Playgroud)
对于张量流实现,您可以使用tf.data.Dataset.shardwith tf.data.TFRecordDataset。
该文档解决了您使用 TFRecords 的低效率问题:
重要警告:
在使用任何随机化运算符(例如 shuffle)之前,请务必进行分片。
通常,最好在数据集管道中尽早使用分片运算符。例如,从一组 TFRecord 文件中读取时,在将数据集转换为输入样本之前进行分片。这避免了读取每个工人的每个文件。以下是完整管道中有效分片策略的示例:
Run Code Online (Sandbox Code Playgroud)d = Dataset.list_files(pattern) d = d.shard(num_workers, worker_index) d = d.repeat(num_epochs) d = d.shuffle(shuffle_buffer_size) d = d.interleave(tf.data.TFRecordDataset, cycle_length=num_readers, block_length=1) d = d.map(parser_fn, num_parallel_calls=num_map_threads)
| 归档时间: |
|
| 查看次数: |
986 次 |
| 最近记录: |