如何在 Tensorflow 2.0 中使用 gradient_override_map?

Ion*_*ons 5 python tensorflow tensorflow2.0

我正在尝试gradient_override_map与 Tensorflow 2.0一起使用。文档中有一个示例,我也将在此处用作示例。

在 2.0 中,GradientTape可用于计算梯度如下:

import tensorflow as tf
print(tf.version.VERSION)  # 2.0.0-alpha0

x = tf.Variable(5.0)
with tf.GradientTape() as tape:
    s_1 = tf.square(x)
print(tape.gradient(s_1, x))
Run Code Online (Sandbox Code Playgroud)

还有tf.custom_gradient装饰器,可用于定义函数的渐变(再次使用文档中示例):

import tensorflow as tf
print(tf.version.VERSION)  # 2.0.0-alpha

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)

    def grad(dy):
        return dy * (1 - 1 / (1 + e))

    return tf.math.log(1 + e), grad

x = tf.Variable(100.)

with tf.GradientTape() as tape:
    y = log1pexp(x)

print(tape.gradient(y, x))
Run Code Online (Sandbox Code Playgroud)

但是,我想替换标准函数的渐变,例如tf.square. 我尝试使用以下代码:

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return tf.constant(0)

with tf.Graph().as_default() as g:
    x = tf.Variable(5.0)
    with g.gradient_override_map({"Square": "CustomSquare"}):
        with tf.GradientTape() as tape:
            s_2 = tf.square(x, name="Square")

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())            
        print(sess.run(tape.gradient(s_2, x)))
Run Code Online (Sandbox Code Playgroud)

但是,有两个问题:梯度替换似乎不起作用(它被评估为10.0而不是0.0),我需要求助于session.run()执行图形。有没有办法在“原生”TensorFlow 2.0 中实现这一点?

在 TensorFlow 1.12.0 中,以下生成所需的输出:

import tensorflow as tf
print(tf.__version__)  # 1.12.0

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return tf.constant(0)

x = tf.Variable(5.0)

g = tf.get_default_graph()
with g.gradient_override_map({"Square": "CustomSquare"}):
    s_2 = tf.square(x, name="Square")
grad = tf.gradients(s_2, x)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(grad))
Run Code Online (Sandbox Code Playgroud)

mrr*_*rry 7

TensorFlow 2.0 中没有内置机制来覆盖范围内内置运算符的所有梯度。但是,如果您能够为对内置运算符的每次调用修改调用站点,则可以tf.custom_gradient按如下方式使用装饰器:

@tf.custom_gradient
def custom_square(x):
  def grad(dy):
    return tf.constant(0.0)
  return tf.square(x), grad

with tf.Graph().as_default() as g:
  x = tf.Variable(5.0)
  with tf.GradientTape() as tape:
    s_2 = custom_square(x)

  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())            
    print(sess.run(tape.gradient(s_2, x)))
Run Code Online (Sandbox Code Playgroud)

  • `tf.compat.v1` 兼容性模块包含最新版本 TF 1.x 中 `tf` 模块的所有内容(除了 `tf.contrib`)。在可预见的未来,没有计划从 TensorFlow 中删除它,因为许多库仍然依赖它,尽管新功能开发将集中在主模块上,并且新旧 API 之间的兼容性可能存在差距(尽管,幸运的是,这个案例有效!)。 (2认同)