等式比较在 TensorFlow 2.0 tf.function() 中不起作用

cs9*_*s95 5 python tensorflow tensorflow2.0

在讨论TensorFlow 2.0 AutoGraphs 之后,我一直在玩,并注意到不等式比较如><是直接指定的,而等式比较使用tf.equal.

这里有一个例子来演示。此函数使用>运算符并在调用时运行良好

@tf.function
def greater_than_zero(value):
    return value > 0

greater_than_zero(tf.constant(1))
#  <tf.Tensor: id=1377, shape=(), dtype=bool, numpy=True>
greater_than_zero(tf.constant(-1))
# <tf.Tensor: id=1380, shape=(), dtype=bool, numpy=False>
Run Code Online (Sandbox Code Playgroud)

这是另一个使用相等比较的函数,但不起作用

@tf.function
def equal_to_zero(value):
    return value == 0

equal_to_zero(tf.constant(1))
# <tf.Tensor: id=1389, shape=(), dtype=bool, numpy=False>  # OK...

equal_to_zero(tf.constant(0))
# <tf.Tensor: id=1392, shape=(), dtype=bool, numpy=False>  # WHAT?
Run Code Online (Sandbox Code Playgroud)

如果我将==相等比较更改为tf.equal,它将起作用。

@tf.function
def equal_to_zero2(value):
    return tf.equal(value, 0)

equal_to_zero2(tf.constant(0))
# <tf.Tensor: id=1402, shape=(), dtype=bool, numpy=True>
Run Code Online (Sandbox Code Playgroud)

我的问题是:为什么在tf.function函数内部使用不等式比较运算符可以工作,而等式比较则不行?

nes*_*uno 5

我在“分析 tf.function 以发现 Autograph 的优势和微妙之处”一文的第 3 部分中分析了这种行为(我强烈建议阅读所有 3 部分以了解如何在使用tf.function- 底部的链接装饰函数之前正确编写函数)答案)。

对于__eq__tf.equal问题,答案是:

简而言之:__eq__运算符 (for tf.Tensor) 已被覆盖,但该运算符不tf.equal用于检查 Tensor 是否相等,它只是检查 Python 变量标识(如果您熟悉 Java 编程语言,这就像 = = 用于字符串对象的运算符)。原因是该tf.Tensor对象需要是可散列的,因为它在 Tensorflow 代码库中的任何地方都被用作 dict 对象的键。

而对于所有其他运算符,答案是 AutoGraph 不会将 Python 运算符转换为 TensorFlow 逻辑运算符。在AutoGraph (don't) 如何转换运算符部分中,我展示了每个 Python 运算符都被转换为始终被评估为 false 的图形表示。

事实上,下面的例子产生输出“wat”

@tf.function
def if_elif(a, b):
  if a > b:
    tf.print("a > b", a, b)
  elif a == b:
    tf.print("a == b", a, b)
  elif a < b:
    tf.print("a < b", a, b)
  else:
    tf.print("wat")
x = tf.constant(1)
if_elif(x,x)
Run Code Online (Sandbox Code Playgroud)

在实践中,AutoGraph 无法将 Python 代码转换为图形代码;我们必须仅使用 TensorFlow 原语来帮助它。在这种情况下,您的代码将按预期工作。

@tf.function
def if_elif(a, b):
  if tf.math.greater(a, b):
    tf.print("a > b", a, b)
  elif tf.math.equal(a, b):
    tf.print("a == b", a, b)
  elif tf.math.less(a, b):
    tf.print("a < b", a, b)
  else:
    tf.print("wat")
Run Code Online (Sandbox Code Playgroud)

我把这三篇文章的链接放在这里,我想你会发现它们很有用:

第1部分第2部分第3部分

  • Money line:“实际上,AutoGraph 无法将 Python 代码转换为图形代码;我们必须仅使用 TensorFlow 原语来帮助它。” 也会阅读这些文章,看起来内容很丰富。谢谢! (2认同)
  • @cs95 我通过解决问题来学习 TensorFlow。当我还是一名研究员时,我必须自学如何进行对象分类和定位。那时(2015年底)tf是个新东西,我花了很多时间学习它,始终以解决该任务为目标。然后,一旦我对高级 API 变得更有信心(并且我解决了我的研究问题),我就开始回答 SO,以检查我对框架结构的理解是否正确 - 这很有帮助,我强烈推荐正在做。总之,有目标,用tf解决,学习 (2认同)