Keras:是否有 train_on_batch 的示例代码,其中包含历史记录+进度?

Moh*_*han 0 python keras tensorflow

在 Python 中,我从 Keras 转向Model.fit循环Model.train_on_batch以实现更精细的控制。但返回的进度条和History对象fit很有用。在浪费时间从头开始实现它们之前,我想知道是否有人找到了使用train_on_batch重现进度条和历史记录的示例代码?

(注意。我查看了 的源代码fit,但是有足够多的间接层,因此很难准确地挖掘出它在做什么。还发现了this,这很有帮助,但没有相关功能。)

Moh*_*han 6

定义和验证数据后EPOCHStrain_generatorval_x, val_y可以替换

history = model.fit(train_generator, validation_data = (val_x, val_y), epochs = EPOCHS)
Run Code Online (Sandbox Code Playgroud)

使用以下代码:

callbacks = tf.keras.callbacks.CallbackList(
    None, 
    add_history = True,
    add_progbar = True,
    model = model,
    epochs = EPOCHS,
    verbose = 1,
    steps = len(train_generator)
)

callbacks.on_train_begin()
for epoch in range(EPOCHS):
    model.reset_metrics()
    callbacks.on_epoch_begin(epoch)
    for i in range(len(train_generator)):
        callbacks.on_train_batch_begin(i)
        logs = model.train_on_batch(*train_generator[i], reset_metrics = False, return_dict = True)              
        callbacks.on_train_batch_end(i, logs)

    validation_logs = model.evaluate(val_x, val_y, callbacks = callbacks, return_dict = True)
    logs.update({'val_' + name: v for name, v in validation_logs.items()})

    callbacks.on_epoch_end(epoch, logs)
    train_generator.on_epoch_end()

callbacks.on_train_end(epoch_logs)
history = model.history
Run Code Online (Sandbox Code Playgroud)