如何使用 TF 2.0 tf.recompute_grad?

Yan*_*kin 6 keras tensorflow

我想使用内存节省梯度(openai/gradient-checkpointing)来减少我的神经网络的 GPU 内存成本,但我发现这在 TF 2.0 中是不可能的,但我也发现我可以使用 tf.recompute_grad 来达到此目的。我在Google上没有找到任何示例或教程,所以我在这里询问。另外,是否可以将其与 tf.keras 一起使用?