使用@tf.function 进行自定义张量流训练的内存泄漏

est*_*ito 7 python keras tensorflow custom-training

我正在尝试TF2/Keras按照官方 Keras 演练为 编写自己的训练循环。vanilla 版本的效果很好,但是当我尝试将@tf.function装饰器添加到我的训练步骤时,一些内存泄漏占用了我所有的内存并且我失去了对机器的控制,有谁知道发生了什么?

代码的重要部分如下所示:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = siamese_network(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, siamese_network.trainable_weights)
    optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = siamese_network(x, training=False)
    val_acc_metric.update_state(y, val_logits)
    val_prec_metric.update_state(y_batch_val, val_logits)
    val_rec_metric.update_state(y_batch_val, val_logits)


for epoch in range(epochs):
        step_time = 0
        epoch_time = time.time()
        print("Start of {} epoch".format(epoch))
        for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
            if step > steps_epoch:
                break
           
            loss_value = train_step(x_batch_train, y_batch_train)
        train_acc = train_acc_metric.result()
        train_acc_metric.reset_states()
        
        for val_step,(x_batch_val, y_batch_val) in enumerate(test_ds):
            if val_step>validation_steps:
                break
            test_step(x_batch_val, y_batch_val)
         
        val_acc = val_acc_metric.result()
        val_prec = val_prec_metric.result()
        val_rec = val_rec_metric.result()

        val_acc_metric.reset_states()
        val_prec_metric.reset_states()
        val_rec_metric.reset_states()
Run Code Online (Sandbox Code Playgroud)

如果我对这些@tf.function行发表评论,则不会发生内存泄漏,但步骤时间要慢 3 倍。我的猜测是以某种方式在每个时代或类似的东西中再次创建了该图,但我不知道如何解决它。

这是我正在关注的教程:https : //keras.io/guides/writing_a_training_loop_from_scratch/

xdh*_*ore 3

太长;博士;

TensorFlow 可能会为传递到修饰函数中的每组唯一的参数值生成一个新图。确保将形状一致的Tensor对象传递给python 对象test_step,而train_step不是传递给 python 对象。

细节

这是暗中刺伤。虽然我从未尝试过,但我确实在文档@tf.function中发现了以下警告:

tf.function 还将任何纯 Python 值视为不透明对象,并为其遇到的每组 Python 参数构建单独的图。

注意:将 python 标量或列表作为参数传递给 tf.function 将始终构建一个新图。为了避免这种情况,请尽可能将数字参数作为张量传递

最后:

Function 通过根据输入的 args 和 kwargs 计算缓存键来确定是否重用跟踪的 ConcreteFunction。缓存键是根据 Function 调用的输入 args 和 kwargs 标识 ConcreteFunction 的键,根据以下规则(可能会更改):

  • 为 tf.Tensor 生成的关键是它的形状和数据类型。
  • 为 tf.Variable 生成的键是唯一的变量 id。
  • 为 Python 原语(如 int、float、str)生成的键是它的值。
  • 为嵌套字典、列表、元组、namedtuples 和 attrs 生成的键是叶键的扁平化元组(请参阅 Nest.flatten)。(由于这种扁平化,调用具有与跟踪期间使用的嵌套结构不同的嵌套结构的具体函数将导致类型错误)。
  • 对于所有其他 Python 类型,该键对于该对象来说是唯一的。通过这种方式,可以独立地跟踪调用函数或方法的每个实例。

我从这一切中得到的是,如果您不将大小一致的 Tensor 对象传递给您的@tf.function-ified 函数(也许您使用 Python 集合或基元代替),那么您很可能正在创建一个新的图形版本函数具有您传入的每个不同的参数值。我猜这可能会造成您所看到的内存爆炸行为。我无法告诉您的test_dstrain_ds对象是如何创建的,但您可能希望确保它们的创建方式能够像enumerate(blah_ds)教程中那样返回张量,或者至少在传递给您的test_steptrain_step函数之前将值转换为张量。