加速 Tensorflow 2.0 Gradient Tape

Eri*_*che 2 python keras tensorflow

我一直在关注卷积 VAE 的 TF 2.0 教程,位于此处

由于它是急切的,梯度是手动计算的,然后使用 tf.GradientTape() 手动应用。

for epoch in epochs:
  for x in x_train:
    with tf.GradientTape() as tape:
      loss = compute_loss(model, x)
    apply_gradients(tape.gradient(loss, model.trainable_variables))
Run Code Online (Sandbox Code Playgroud)

该代码的问题在于它非常慢,每个 epoch 大约需要 40-50 秒。如果我将批量大小增加很多(到 2048 左右),那么每个 epoch 最终需要大约 8 秒,但是模型的性能会下降很多。

另一方面,如果我做一个更传统的模型(即,使用基于惰性图的模型而不是热心模型),例如这里的模型,那么即使批量较小,每个 epoch 也需要 8 秒。

model.add_loss(lazy_graph_loss)
model.fit(x_train epochs=epochs)
Run Code Online (Sandbox Code Playgroud)

基于这些信息,我的猜测是 TF2.0 代码的问题在于手动计算损失和梯度。

有什么办法可以加速TF2.0代码,使其更接近正常代码?

Eri*_*che 5

我找到了解决方案:TensorFlow 2.0 引入了函数的概念,将 Eager 代码转换为图形代码。

用法非常简单。唯一需要更改的是所有相关函数(如compute_lossapply_gradients)都必须用 注释@tf.function