Moh*_*han 0 python keras tensorflow
在 Python 中,我从 Keras 转向Model.fit循环Model.train_on_batch以实现更精细的控制。但返回的进度条和History对象fit很有用。在浪费时间从头开始实现它们之前,我想知道是否有人找到了使用train_on_batch重现进度条和历史记录的示例代码?
(注意。我查看了 的源代码fit,但是有足够多的间接层,因此很难准确地挖掘出它在做什么。还发现了this,这很有帮助,但没有相关功能。)
定义和验证数据后EPOCHS,train_generator您val_x, val_y可以替换
history = model.fit(train_generator, validation_data = (val_x, val_y), epochs = EPOCHS)
Run Code Online (Sandbox Code Playgroud)
使用以下代码:
callbacks = tf.keras.callbacks.CallbackList(
None,
add_history = True,
add_progbar = True,
model = model,
epochs = EPOCHS,
verbose = 1,
steps = len(train_generator)
)
callbacks.on_train_begin()
for epoch in range(EPOCHS):
model.reset_metrics()
callbacks.on_epoch_begin(epoch)
for i in range(len(train_generator)):
callbacks.on_train_batch_begin(i)
logs = model.train_on_batch(*train_generator[i], reset_metrics = False, return_dict = True)
callbacks.on_train_batch_end(i, logs)
validation_logs = model.evaluate(val_x, val_y, callbacks = callbacks, return_dict = True)
logs.update({'val_' + name: v for name, v in validation_logs.items()})
callbacks.on_epoch_end(epoch, logs)
train_generator.on_epoch_end()
callbacks.on_train_end(epoch_logs)
history = model.history
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1775 次 |
| 最近记录: |