Tensorflow cifar同步点

blu*_*sky 6 deep-learning tensorflow

阅读https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py函数average_gradients,提供以下注释:Note that this function provides a synchronization point across all towers.函数average_gradients是阻塞调用,是什么意思synchronization point

我假设这是一个阻塞调用,因为为了计算每个梯度必须单独计算的梯度的平均值?但是等待所有个别梯度计算的阻塞代码在哪里?

Blu*_*Sun 6

average_gradients本身不是阻塞函数.它可能是具有张量流操作的另一个函数,这仍然是一个同步点.使它阻塞的原因是它使用了tower_grads依赖于前一个for循环中创建的所有图形的参数.

基本上这里发生的是创建训练图.首先,在for循环for i in xrange(FLAGS.num_gpus)中创建了几个图形"线程".每个看起来像这样:

计算损失 - >计算梯度 - >追加 tower_grads

这些图形"线程"中的with tf.device('/gpu:%d' % i)每一个都被分配给不同的gpu,并且每个图形可以彼此独立地运行(并且稍后将并行运行).现在,下次tower_grads使用时没有设备规范,它会在主设备上创建一个图形延续,将所有这些单独的图形"线程"绑定到一个图形.Tensorflow将确保tower_grads在运行average_gradients函数内部的图形之前完成创建的一部分图形"线程" .因此,稍后sess.run([train_op, loss])调用时,这将是图形的同步点.