Jax、jit 和动态形状:Tensorflow 的回归?

use*_*974 6 python tensorflow jax

JAX 的文档说,

并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状是静态的且在编译时已知。

现在我有点惊讶,因为 TensorFlow 具有类似的操作,tf.boolean_mask而 JAX 在编译时似乎无法执行此操作。

  1. 为什么 Tensorflow 会出现这样的回归?我假设底层 XLA 表示在两个框架之间共享,但我可能是错的。我不记得 Tensorflow 在动态形状和函数方面遇到过麻烦,比如tf.boolean_mask一直存在的问题。
  2. 我们可以预期这种差距在未来会缩小吗?如果不是,为什么在 JAX' jit 中无法实现 Tensorflow(以及其他)所支持的功能?

编辑

梯度通过tf.boolean_mask(显然不是在掩码值上,掩码值是离散的);这里使用 TF1 样式图的例子,其中值未知,因此 TF 不能依赖它们:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x1 = tf.placeholder(tf.float32, (3,))
x2 = tf.placeholder(tf.float32, (3,))
y = tf.boolean_mask(x1, x2 > 0)
print(y.shape)  # prints "(?,)"
dydx1, dydx2 = tf.gradients(y, [x1, x2])
assert dydx1 is not None and dydx2 is None
Run Code Online (Sandbox Code Playgroud)

May*_*u36 4

目前,您不能如此处讨论的

这不是 JAX jit 与 TensorFlow 的限制,而是 XLA 的限制,或者更确切地说是两者编译方式的限制。

JAX 仅使用 XLA 来编译该函数。XLA需要知道静态形状。这是XLA固有的设计选择。

TensorFlow 使用function:这会创建一个可以具有静态未知形状的图。这不如使用 XLA 高效,但仍然很好。但是,tf.function提供了一个选项jit_compile,它将使用 XLA 编译函数内的图形。虽然这通常可以提供不错的加速(免费),但它也有限制:形状需要静态已知(惊喜,惊喜……)

总体而言,这并不是太令人惊讶的行为:计算机中的计算通常会更快(假设有一个像样的优化器对其进行了处理),以前称为可以优化调度的参数(内存布局等)越多。知道的越少,代码就越慢(这方面是普通的 Python)。