我可以使用 ktrain 库从检查点恢复训练吗?

MS9*_*S91 2 python keras

ktrain 是深度学习库 TensorFlow Keras(和其他库)的轻量级包装器,可帮助构建、训练和部署神经网络和其他机器学习模型。我可以使用 ktrain 库从检查点恢复训练吗?

blu*_*tax 5

是的你可以。ktrain 常见问题解答中对此进行了解答。我将答案复制在这里:

方法 1:使用 Predictor API(适用于任何模型)

# save model and Preprocessor instance after partially training
ktrain.get_predictor(model, preproc).save('/tmp/my_predictor')

# reload Predictor and extract model
model = ktrain.load_predictor('/tmp/my_predictor').model

# re-instantiate Learner and continue training
learner = ktrain.get_learner(model, train_data=trn, val_data=val)
learner.fit_onecycle(2e-5, 1)
Run Code Online (Sandbox Code Playgroud)

请注意,preproc这里是一个预处理器实例。texts_from_csv如果使用像或 这样的数据加载函数images_from_folder,它将是该函数的第三个返回值。或者,如果使用Transformer API进行文本分类,它将是调用的输出text.Transformer(即preproc = text.Transformer('bert-base-uncased', ...))。

方法2:使用transformers库(如果训练Hugging Face Transformers模型)

如果模型是 Hugging Face 变形金刚模型,则可以transformers直接使用:

# save model using transformers API after partially training
learner.model.save_pretrained('/tmp/my_model')

# reload the model using transformers directly
from transformers import *
model = TFAutoModelForSequenceClassification.from_pretrained('/tmp/my_model')
model.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])

# re-instantiate Learner and continue training
learner = ktrain.get_learner(model, train_data=trn, val_data=val)
learner.fit_onecycle(2e-5, 1)
Run Code Online (Sandbox Code Playgroud)

方法3:使用checkpoint_folder参数保存模型权重

参数checkpoint_folder(例如learner.autofit(1e-4, 4, checkpoint_folder='/tmp/saved_weights'))仅在每个时期后保存模型的权重。model.load_weights任何时期的权重都可以使用通常在 中使用的方法重新加载到模型中tf.Keras。您只需要先重新创建模型即可。例如,如果训练 NER 模型,它将按如下方式工作:

# recreate model from scratch
import ktrain
from ktrain import text
model = text.sequence_tagger(...
# load checkpoint weights from 3rd epoch into model
model.load_weights('../models/checkpoints/weights-03.hdf5')
# recreate learner
learner = ktrain.get_learner(model, ...
# continue training here
Run Code Online (Sandbox Code Playgroud)

最后,还有一种learner.save_model方法learner.load_model用于在单个会话期间进行交互训练时保存和重新加载模型。