Keras - 没有停止和恢复训练的好方法吗?

Dan*_*iel 6 python keras tensorflow tf.keras tensorflow2.0

经过大量研究,似乎没有好的方法可以使用 Tensorflow 2 / Keras 模型正确停止和恢复训练。无论您是使用model.fit() 还是使用自定义训练循环,都是如此。

似乎有 2 种支持的方法可以在训练时保存模型:

  1. 保存模型的只是权重,使用model.save_weights()save_weights_only=Truetf.keras.callbacks.ModelCheckpoint。这似乎是我见过的大多数示例的首选,但是它有许多主要问题:

    • 优化器状态未保存,这意味着训练恢复将不正确。
    • 学习率计划被重置——这对于某些模型来说可能是灾难性的。
    • Tensorboard 日志返回到第 0 步 - 除非实施复杂的解决方法,否则日志记录基本上毫无用处。
  2. 使用model.save()或保存整个模型、优化器等save_weights_only=False。优化器状态已保存(良好),但仍存在以下问题:

    • Tensorboard 日志仍然回到步骤 0
    • 学习率计划仍在重置(!!!)
    • 无法使用自定义指标。
    • 这在使用自定义训练循环时根本不起作用 - 自定义训练循环使用非编译模型,并且似乎不支持保存/加载非编译模型。

我发现的最佳解决方法是使用自定义训练循环,手动保存步骤。这修复了 tensorboard 日志记录,并且可以通过执行类似的操作来修复学习率计划keras.backend.set_value(model.optimizer.iterations, step)。但是,由于完整的模型保存不在表中,因此不会保留优化器状态。我看不出有什么方法可以独立保存优化器的状态,至少不需要做很多工作。像我一样搞乱 LR 时间表也感觉很混乱。

我错过了什么吗?人们如何使用此 API 保存/恢复?

Ove*_*gon 5

你说得对,没有内置的可恢复性支持——这正是我创建DeepTrain 的动机。这就像 TensorFlow/Keras 的 Pytorch Lightning(在不同方面越来越好)。

为什么是另一个图书馆?我们还不够吗?你没有这样的东西;如果有,我不会建造它。DeepTrain 专为“保姆式”训练而量身定制:训练较少的模型,但对其进行彻底的训练。密切监视每个阶段,以诊断出了什么问题以及如何解决。

灵感来自我自己的使用;我会在很长的 epoch 中看到“验证尖峰”,并且无法暂停,因为它会重新启动 epoch 或以其他方式破坏火车循环。忘记知道您正在安装哪个批次,或者还剩下多少。

与 Pytorch 闪电相比如何?卓越的可恢复性和内省性,以及独特的列车调试实用程序 - 但 Lightning 在其他方面表现更好。我在工作中有一个综合列表比较,将在一周内发布。

Pytorch 支持来了?也许。如果我说服 Lightning 开发团队弥补其相对于 DeepTrain 的缺点,那么不会——否则很可能。同时,您可以浏览示例库。


最小的例子

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from deeptrain import TrainGenerator, DataGenerator

ipt = Input((16,))
out = Dense(10, 'softmax')(ipt)
model = Model(ipt, out)
model.compile('adam', 'categorical_crossentropy')

dg  = DataGenerator(data_path="data/train", labels_path="data/train/labels.npy")
vdg = DataGenerator(data_path="data/val",   labels_path="data/val/labels.npy")
tg  = TrainGenerator(model, dg, vdg, epochs=3, logs_dir="logs/")

tg.train()
Run Code Online (Sandbox Code Playgroud)

您可以KeyboardInterrupt随时检查模型、训练状态、数据生成器 - 并恢复。


小智 5

tf.keras.callbacks.experimental.BackupAndRestore已添加用于从中断中恢复训练的 API tensorflow>=2.3。根据我的经验,它效果很好。

参考: https: //www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore