有人可以gradient_override_map在TensorFlow中解释我的功能吗?我无法准确理解它的用法.
我看到代码用法为:
with G.gradient_override_map({"Floor": "Identity"}):
return tf.reduce_mean(SomeVals) * SomeOtherVal
Run Code Online (Sandbox Code Playgroud)
到底发生了什么?什么是Identity?
“ Floor”和“ Identity”都是操作的类型字符串,前者对应于tf.floor,而后者对应于tf.identity。因此,我想您的代码的功能是在传递tf.reduce_mean的输出的同时,将tf.identity的反向传播梯度(简称BPG)计算机制替换为图G 中tf.floor操作的BPG计算机制。gradient_override_map到目前为止,在所有应用程序中,op_type_map的键始终与用于在上下文中生成输出的操作的类型字符串相同,这似乎有点不可思议。我的意思是说我更熟悉带有return tf.floor(SomeVals)而不是的场景tf.reduce_mean(SomeVals)。
什么gradient_override_map({op_A_type: op_B_type})是用op_B替换op_A的BPG计算机制,同时保留op_A_type的前向传播计算机制。lahwran的答案显示了gradient_override_map的常见应用。
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
Run Code Online (Sandbox Code Playgroud)
通过
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
Run Code Online (Sandbox Code Playgroud)
装饰器,tf.RegisterGradient("CustomGrad")注册_const_mul_grad(unused_op, grad)为自定义op类型定义的渐变函数-“ CustomGrad”,
而
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
Run Code Online (Sandbox Code Playgroud)
确保字符串类型“ Identity”(tf.identity)的所有操作(在图g中)的输出保持原样,而tf.identity的BPG计算机制由字符串类型“ CustomGrad”的操作的BPG计算机制代替。
聚苯乙烯
op的类型字符串与OpDef.name定义操作的原型字段相对应。要找到操作者OpDef.name,请参考此问题下的MingXing的答案
不需要声明tf.identity操作的名称,因为tf.identity中的arg'name ' 是可选的。
据我所知,gradient_override_map 允许您说“在这种情况下,任何时候您将使用 X 的梯度,而不是使用 Y 的梯度”。这意味着您仍然需要Y 的梯度作为您想要使用的梯度。
这是我在寻找其工作原理时看到的一个示例:
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
Run Code Online (Sandbox Code Playgroud)
引用:https ://stackoverflow.com/a/43948872/1102705
RegisterGradient()允许您注册正在定义的新操作的梯度,从而允许您拥有一个具有所需梯度的操作,然后您可以在梯度覆盖映射中使用该操作。这有点笨拙——你正在定义一个没有前向传递的操作。
我不清楚 name="Identity" 是否真的有必要。
| 归档时间: |
|
| 查看次数: |
5167 次 |
| 最近记录: |