小编est*_*ito的帖子

使用@tf.function 进行自定义张量流训练的内存泄漏

我正在尝试TF2/Keras按照官方 Keras 演练为 编写自己的训练循环。vanilla 版本的效果很好,但是当我尝试将@tf.function装饰器添加到我的训练步骤时,一些内存泄漏占用了我所有的内存并且我失去了对机器的控制,有谁知道发生了什么?

代码的重要部分如下所示:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = siamese_network(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, siamese_network.trainable_weights)
    optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = siamese_network(x, training=False)
    val_acc_metric.update_state(y, val_logits)
    val_prec_metric.update_state(y_batch_val, val_logits)
    val_rec_metric.update_state(y_batch_val, val_logits)


for epoch in range(epochs):
        step_time = 0
        epoch_time = time.time()
        print("Start of {} epoch".format(epoch))
        for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
            if step > steps_epoch:
                break
           
            loss_value = train_step(x_batch_train, y_batch_train)
        train_acc …
Run Code Online (Sandbox Code Playgroud)

python keras tensorflow custom-training

7
推荐指数
1
解决办法
436
查看次数

标签 统计

custom-training ×1

keras ×1

python ×1

tensorflow ×1