Tensorflow的gradient_override_map函数

mac*_*c_i 7 python tensorflow

有人可以gradient_override_map在TensorFlow中解释我的功能吗?我无法准确理解它的用法.

我看到代码用法为:

with G.gradient_override_map({"Floor": "Identity"}):
    return tf.reduce_mean(SomeVals) * SomeOtherVal
Run Code Online (Sandbox Code Playgroud)

到底发生了什么?什么是Identity

Yil*_* He 6

“ Floor”和“ Identity”都是操作的类型字符串,前者对应于tf.floor,而后者对应于tf.identity因此,我想您的代码的功能是在传递tf.reduce_mean的输出的同时,将tf.identity的反向传播梯度(简称BPG)计算机制替换为图G 中tf.floor操作的BPG计算机制gradient_override_map到目前为止,在所有应用程序中,op_type_map的键始终与用于在上下文中生成输出的操作的类型字符串相同,这似乎有点不可思议。我的意思是说我更熟悉带有return tf.floor(SomeVals)而不是的场景tf.reduce_mean(SomeVals)

什么gradient_override_map({op_A_type: op_B_type})是用op_B替换op_A的BPG计算机制,同时保留op_A_type的前向传播计算机制。lahwran的答案显示了gradient_override_map的常见应用。

@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity")
Run Code Online (Sandbox Code Playgroud)

通过

@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad
Run Code Online (Sandbox Code Playgroud)

装饰器,tf.RegisterGradient("CustomGrad")注册_const_mul_grad(unused_op, grad)为自定义op类型定义的渐变函数-“ CustomGrad”,

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity") 
Run Code Online (Sandbox Code Playgroud)

确保字符串类型“ Identity”(tf.identity)的所有操作(在图g中)的输出保持原样,而tf.identity的BPG计算机制由字符串类型“ CustomGrad”的操作的BPG计算机制代替。

聚苯乙烯

  1. op的类型字符串与OpDef.name定义操作的原型字段相对应。要找到操作者OpDef.name,请参考此问题下的MingXing的答案

  2. 不需要声明tf.identity操作的名称,因为tf.identity中的arg'name ' 是可选的。


lah*_*ran 2

据我所知,gradient_override_map 允许您说“在这种情况下,任何时候您将使用 X 的梯度,而不是使用 Y 的梯度”。这意味着您仍然需要Y 的梯度作为您想要使用的梯度。

这是我在寻找其工作原理时看到的一个示例:

@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity")
Run Code Online (Sandbox Code Playgroud)

引用:https ://stackoverflow.com/a/43948872/1102705

RegisterGradient()允许您注册正在定义的新操作的梯度,从而允许您拥有一个具有所需梯度的操作,然后您可以在梯度覆盖映射中使用该操作。这有点笨拙——你正在定义一个没有前向传递的操作。

我不清楚 name="Identity" 是否真的有必要。