AKS*_*HAN 5 python validation keras
如何使用 keras 函数 fit_generator() 来训练并同时保存具有最低验证损失的模型权重?
您可以在定义检查点时设置save_best_only=True:
from keras.callbacks import EarlyStopping, ModelCheckpoint
early_stop = EarlyStopping(
monitor='loss',
min_delta=0.001,
patience=3,
mode='min',
verbose=1
)
checkpoint = ModelCheckpoint(
'model_best_weights.h5',
monitor='loss',
verbose=1,
save_best_only=True,
mode='min',
period=1
)
Run Code Online (Sandbox Code Playgroud)
现在,拟合模型时只需包含参数callbacks = [early_stop,checkpoint]。它将保存具有最低验证损失的权重。
model.fit_generator(X_train, Y_train, validation_data=(X_val, Y_val),
callbacks = [early_stop,checkpoint])
Run Code Online (Sandbox Code Playgroud)
如果您也想保存模型架构,则需要将模型序列化为 JSON:
model_json = model.to_json()
with open("model.json", "w") as json_file:
json_file.write(model_json)
Run Code Online (Sandbox Code Playgroud)
最后加载模型的架构和权重:
# load json and create model
json_file = open('model.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)
# load weights into new model
loaded_model.load_weights("model_best_weights.h5")
print("Loaded model from disk")
# evaluate loaded model on test data
loaded_model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
score = loaded_model.evaluate(X, Y, verbose=0)
Run Code Online (Sandbox Code Playgroud)
请参考:https://machinelearningmastery.com/save-load-keras-deep-learning-models/
您可以使用以下代码保存模型权重。
model.save_weights('weights.h5')
Run Code Online (Sandbox Code Playgroud)
您可以使用以下代码保存模型的架构:
model.save('architecure.h5')
Run Code Online (Sandbox Code Playgroud)
如果空间不是问题,那么您可以存储所有模型并选择验证损失最低的模型。
或者,您可以在每个时期之后使用回调来评估验证损失,并采用当前验证数据损失最低的模型。这可以通过参考以下链接来完成。在此示例中,只需更改传递给 TestCallback 的数据,并使用一个变量来存储当前的最小验证损失。
class TestCallback(Callback):
def __init__(self, test_data):
self.test_data = test_data
def on_epoch_end(self, epoch, logs={}):
x, y = self.test_data
loss, acc = self.model.evaluate(x, y, verbose=0)
print('\nTesting loss: {}, acc: {}\n'.format(loss, acc))
model.fit(X_train, Y_train, validation_data=(X_val, Y_val),
callbacks=[TestCallback((X_test, Y_test))])
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
7132 次 |
| 最近记录: |