use*_*974 6 python tensorflow jax
JAX 的文档说,
并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状是静态的且在编译时已知。
现在我有点惊讶,因为 TensorFlow 具有类似的操作,tf.boolean_mask而 JAX 在编译时似乎无法执行此操作。
tf.boolean_mask一直存在的问题。编辑
梯度通过tf.boolean_mask(显然不是在掩码值上,掩码值是离散的);这里使用 TF1 样式图的例子,其中值未知,因此 TF 不能依赖它们:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
x1 = tf.placeholder(tf.float32, (3,))
x2 = tf.placeholder(tf.float32, (3,))
y = tf.boolean_mask(x1, x2 > 0)
print(y.shape) # prints "(?,)"
dydx1, dydx2 = tf.gradients(y, [x1, x2])
assert dydx1 is not None and dydx2 is None
Run Code Online (Sandbox Code Playgroud)
目前,您不能(如此处讨论的)
这不是 JAX jit 与 TensorFlow 的限制,而是 XLA 的限制,或者更确切地说是两者编译方式的限制。
JAX 仅使用 XLA 来编译该函数。XLA需要知道静态形状。这是XLA固有的设计选择。
TensorFlow 使用function:这会创建一个可以具有静态未知形状的图。这不如使用 XLA 高效,但仍然很好。但是,tf.function提供了一个选项jit_compile,它将使用 XLA 编译函数内的图形。虽然这通常可以提供不错的加速(免费),但它也有限制:形状需要静态已知(惊喜,惊喜……)
总体而言,这并不是太令人惊讶的行为:计算机中的计算通常会更快(假设有一个像样的优化器对其进行了处理),以前称为可以优化调度的参数(内存布局等)越多。知道的越少,代码就越慢(这方面是普通的 Python)。
| 归档时间: |
|
| 查看次数: |
1968 次 |
| 最近记录: |