张量流中每个示例的未聚集梯度/渐变

Bas*_*Bas 15 tensorflow

给定张量流中mnist的简单小批量梯度下降问题(如本教程中所述),如何单独检索批处理中每个示例的渐变.

tf.gradients()似乎返回批次中所有示例的平均梯度.有没有办法在聚合之前检索渐变?

编辑:这个答案的第一步是弄清楚张力流在哪个点上平均了批次中的例子的梯度.我以为这发生在_AggregatedGrads中,但事实并非如此.有任何想法吗?

Yar*_*tov 7

tf.gradients返回与损失有关的梯度.这意味着如果您的损失是每个示例损失的总和,那么梯度也是每个示例损失梯度的总和.

总结是隐含的.例如,如果你希望尽量减少的平方规范的总和Wx-y错误,相对于梯度W2(WX-Y)X'这里X是一批意见,并Y是该批次的标签.你从来没有明确地形成你后面总结的"每个例子"渐变,所以在梯度管道中移除一些阶段并不是一件简单的事情.

获得k每个示例损失梯度的一种简单方法是使用大小为1的批次并进行k传递.Ian Goodfellow 编写了如何k在一次通过中获取所有渐变,为此您需要明确指定渐变而不依赖于tf.gradients方法

  • 您可以使用tf.gradients完成大部分工作.假设您希望每个示例渐变相对于X.您在X的使用者身上调用tf.gradients.例如,您有一个X的变量Z乘以某个矩阵W.然后您需要自己的逻辑来执行 - 通过矩阵乘法进行示例微分,但您可以使用tf.gradients获得与Z相关的导数. (5认同)