Uma*_*aja 11 python keras tensorflow tensorflow2.0 gradienttape
I am writing a custom training loop using the code provided in the Tensorflow DCGAN implementation guide. I would like to add callbacks to the training loop. In Keras I know we pass them as an argument to the 'fit' method, but I can't find resources on how to use these callbacks in a custom training loop. Here is the code for the custom training loop from the Tensorflow documentation:
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)
      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)
      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()
    for image_batch in dataset:
      train_step(image_batch)
    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)
    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)
Rob*_*all 14
I have run into this problem myself: (1) I want to use a custom training loop; (2) I don't want to lose the bells and whistles Keras gives me in terms of callbacks; (3) I don't want to reimplement them all myself. Tensorflow has a design philosophy of allowing developers to gradually opt into its lower-level APIs. As @HyeonPhilYoun points out in the comments below, the official documentation for tf.keras.callbacks.Callback gives an example of exactly what we are looking for.
The following has worked for me, though it could likely be improved by reverse engineering tf.keras.Model.
The trick is to use tf.keras.callbacks.CallbackList in your custom training loop and trigger its lifecycle events manually. This example uses tqdm to provide attractive progress bars, but CallbackList also has an add_progbar initialization argument that lets you use the default progress reporting. training_model is simply a typical instance of tf.keras.Model.
from tqdm.notebook import tqdm, trange
import tensorflow as tf
# Populate with typical keras callbacks
_callbacks = []
callbacks = tf.keras.callbacks.CallbackList(
    _callbacks, add_history=True, model=training_model)
logs = {}
callbacks.on_train_begin(logs=logs)
# Presentation
epochs = trange(
    max_epochs,
    desc="Epoch",
    unit="Epoch",
    postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}")
epochs.set_postfix(loss=0, accuracy=0)
# Get a stable test set so epoch results are comparable
# (`batches` is the author's own batching helper, not shown here)
test_batches = batches(test_x, test_Y)
for epoch in epochs:
    callbacks.on_epoch_begin(epoch, logs=logs)
    # I like to formulate new batches each epoch
    # if there are data augmentation methods in play
    training_batches = batches(x, Y)
    # Presentation
    enumerated_batches = tqdm(
        enumerate(training_batches),
        desc="Batch",
        unit="batch",
        postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}",
        position=1,
        leave=False)
    for (batch, (x, y)) in enumerated_batches:
        training_model.reset_states()
        
        callbacks.on_batch_begin(batch, logs=logs)
        callbacks.on_train_batch_begin(batch, logs=logs)
        
        logs = training_model.train_on_batch(x=x, y=y, return_dict=True)
        callbacks.on_train_batch_end(batch, logs=logs)
        callbacks.on_batch_end(batch, logs=logs)
        # Presentation
        enumerated_batches.set_postfix(
            loss=float(logs["loss"]),
            accuracy=float(logs["accuracy"]))
    for (batch, (x, y)) in enumerate(test_batches):
        training_model.reset_states()
        callbacks.on_batch_begin(batch, logs=logs)
        callbacks.on_test_batch_begin(batch, logs=logs)
        logs = training_model.test_on_batch(x=x, y=y, return_dict=True)
        callbacks.on_test_batch_end(batch, logs=logs)
        callbacks.on_batch_end(batch, logs=logs)
    # Presentation
    epochs.set_postfix(
        loss=float(logs["loss"]),
        accuracy=float(logs["accuracy"]))
    callbacks.on_epoch_end(epoch, logs=logs)
    # NOTE: This is a decent place to check on your early stopping
    # callback.
    # Example: use training_model.stop_training to check for early stopping
callbacks.on_train_end(logs=logs)
# Fetch the history object we normally get from keras.fit
history_object = None
for cb in callbacks:
    if isinstance(cb, tf.keras.callbacks.History):
        history_object = cb
assert history_object is not None
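For example, here is a minimal sketch of how the empty _callbacks list above might be populated, together with the early-stopping check the NOTE above alludes to. The callback arguments and the checkpoint filename are illustrative, not from the original answer:
import tensorflow as tf

# Illustrative population of the callback list:
_callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="loss", patience=5),
    tf.keras.callbacks.ModelCheckpoint("weights.h5", save_weights_only=True),
]
callbacks = tf.keras.callbacks.CallbackList(
    _callbacks, add_history=True, model=training_model)

# Then, inside the epoch loop, right after callbacks.on_epoch_end(...),
# check the flag that EarlyStopping sets when it triggers:
#     if training_model.stop_training:
#         break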
The simplest approach is to check whether the loss has changed over your expected window and, if it hasn't, break out of or otherwise manipulate the training process. Here is one way you could implement a custom early-stopping callback:
import numpy as np

def Callback_EarlyStopping(LossList, min_delta=0.1, patience=20):
    # No early stopping for the first 2*patience epochs
    if len(LossList)//patience < 2 :
        return False
    # Mean loss over the last `patience` epochs and the `patience` epochs before that
    mean_previous = np.mean(LossList[::-1][patience:2*patience]) # second-last
    mean_recent = np.mean(LossList[::-1][:patience]) # last
    # You can use either relative or absolute change
    delta_abs = np.abs(mean_recent - mean_previous) # absolute change
    delta_rel = np.abs(delta_abs / mean_previous)   # relative change
    if delta_rel < min_delta :
        print("*CB_ES* Loss didn't change much from last %d epochs"%(patience))
        print("*CB_ES* Percent change in loss value:", delta_rel*1e2)
        return True
    else:
        return False
This Callback_EarlyStopping checks your metric/loss every epoch and returns True if its relative change is smaller than what you expect, which it estimates by computing moving averages of the loss over blocks of `patience` epochs. You can then capture this True signal and break the training loop. To fully answer your question, within your sample training loop you could use it like this:
gen_loss_seq = []
for epoch in range(epochs):
  # In your example, modify train_step to return gen_loss
  for image_batch in dataset:
    gen_loss = train_step(image_batch)
  # Ideally you would also have a validation_step that returns gen_valid_loss
  gen_loss_seq.append(float(gen_loss))
  # Check every 20 epochs and stop if the loss hasn't changed by 10%
  stopEarly = Callback_EarlyStopping(gen_loss_seq, min_delta=0.1, patience=20)
  if stopEarly:
    print("Callback_EarlyStopping signal received at epoch= %d/%d"%(epoch,epochs))
    print("Terminating training ")
    break
Of course, you can add complexity in numerous ways: for example, which losses or metrics you want to track, whether you care about the loss at a particular epoch or a moving average of it, whether you care about relative or absolute change in the value, and so on. You can refer to the Tensorflow 2.x implementation of tf.keras.callbacks.EarlyStopping, which is what the popular tf.keras.Model.fit method generally uses.
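For comparison, that library callback keeps the best value seen so far rather than comparing moving averages. A minimal sketch of that strategy for a custom loop (illustrative names, not the library's actual code):
def early_stop_check(loss, state, min_delta=1e-3, patience=20):
    """Return True once `loss` hasn't improved on the best value for `patience` epochs."""
    if state["best"] - loss > min_delta:
        state["best"] = loss   # improvement: remember it and reset the counter
        state["wait"] = 0
    else:
        state["wait"] += 1
    return state["wait"] >= patience

state = {"best": float("inf"), "wait": 0}
# Inside your epoch loop:
#     if early_stop_check(float(gen_loss), state):
#         break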
小智 2
I think you would need to implement the functionality of the callbacks manually. It should not be too difficult. You could, for instance, have the 'train_step' function return the losses, and then implement callback functionality such as early stopping in your 'train' function. For callbacks such as learning rate schedules, the function tf.keras.backend.set_value(generator_optimizer.lr, new_lr) would come in handy. Therefore the functionality of the callback would be implemented inside your 'train' function.
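For instance, here is a minimal sketch of a per-epoch learning rate schedule applied to the question's training loop. The decay constants are illustrative; generator_optimizer, epochs, dataset, and train_step come from the question's code:
import tensorflow as tf

initial_lr = 1e-4   # illustrative starting learning rate
decay_rate = 0.95   # illustrative per-epoch decay factor

for epoch in range(epochs):
    # Compute the schedule by hand, then push it into the optimizer
    new_lr = initial_lr * decay_rate ** epoch
    tf.keras.backend.set_value(generator_optimizer.lr, new_lr)
    for image_batch in dataset:
        train_step(image_batch)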