Saving the history of model.fit across separate training runs

Sra*_*mar 5 epoch neural-network keras tensorflow

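I am training my model with epochs=10. Then I retrain it with epochs=3, and then again with epochs=5. So the model is trained three times, with epochs = 10, 3 and 5, and I want to combine the history of all three runs. For example, let h1 = the history from model.fit with epochs=10, h2 = the history from model.fit with epochs=3, and h3 = the history from model.fit with epochs=5.

Now I want h1 + h2 + h3 in a single variable h, i.e. all the histories appended together, so that I can plot some graphs.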

Here is the code:

import time

# hms_string, tensorboard and checkpoint are defined elsewhere in the script
start_time = time.time()

# model.fit returns a History object; keep it so the runs can be combined later
h1 = model.fit(x=X_train, y=y_train, batch_size=32, epochs=10, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")


start_time = time.time()

h2 = model.fit(x=X_train, y=y_train, batch_size=32, epochs=3, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")

start_time = time.time()

h3 = model.fit(x=X_train, y=y_train, batch_size=32, epochs=5, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")
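Since model.fit returns a History object whose .history attribute is a dict mapping each metric name to a list with one value per epoch, "h1 + h2 + h3" can be sketched as a per-key list concatenation. This is a minimal sketch, assuming every run logged the same metric names; merge_histories is a hypothetical helper name, not a Keras function.

```python
from typing import Dict, List


def merge_histories(*histories: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Concatenate per-epoch metric lists from several History.history dicts, in order."""
    merged: Dict[str, List[float]] = {}
    for history in histories:
        for metric, values in history.items():
            # setdefault creates a fresh list the first time a metric name is seen,
            # so the input dicts are never mutated
            merged.setdefault(metric, []).extend(values)
    return merged


# Usage, assuming h1, h2, h3 hold the History objects of the three runs:
# h = merge_histories(h1.history, h2.history, h3.history)
# plt.plot(range(1, len(h['loss']) + 1), h['loss'])  # 10 + 3 + 5 = 18 points
```

If one run logs a metric the others do not (for example, a run without validation_data has no val_loss), that metric's list will be shorter than the others, so check lengths before plotting.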

小智 8

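You can do this by subclassing tf.keras.callbacks.Callback and passing an instance of the subclass to model.fit through its callbacks argument. The callback appends one row per epoch to a CSV file on disk, so the rows from every fit call (here 10, 3 and 5 epochs) accumulate in the same file.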

import csv
import os

from tensorflow import keras
import tensorflow.keras.backend as K

model_directory = './xyz'  # directory where the history CSV is appended after every epoch
history_path = os.path.join(model_directory, 'model_history.csv')

class StoreModelHistory(keras.callbacks.Callback):
  """Appends the metrics of each finished epoch as one row of a CSV file."""

  def on_epoch_end(self, epoch, logs=None):
    logs = dict(logs or {})
    # Record the current learning rate if Keras did not already log it
    if 'lr' not in logs:
      logs['lr'] = float(K.get_value(self.model.optimizer.learning_rate))

    os.makedirs(model_directory, exist_ok=True)
    write_header = not os.path.exists(history_path)  # header only for a new file
    with open(history_path, 'a', newline='') as f:
      writer = csv.DictWriter(f, fieldnames=list(logs.keys()))
      if write_header:
        writer.writeheader()
      writer.writerow(logs)


model.fit(..., callbacks=[StoreModelHistory()])
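The history survives across separate model.fit calls because it lives in the file, not in the model. Below is a minimal sketch of just the file-handling part of the callback with Keras removed, so it can be run on its own; append_epoch_row is a hypothetical helper name and the metric values in the demo are made up.

```python
import csv
import os
import tempfile
from typing import Dict


def append_epoch_row(path: str, logs: Dict[str, float]) -> None:
    """Append one epoch's metrics to a CSV file, writing the header only if the file is new."""
    write_header = not os.path.exists(path)
    with open(path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(logs.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(logs)


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'model_history.csv')
        # Simulate three separate fit calls of 10, 3 and 5 epochs (made-up losses)
        for run_epochs in (10, 3, 5):
            for _ in range(run_epochs):
                append_epoch_row(path, {'loss': 0.5, 'val_loss': 0.6})
        with open(path, newline='') as f:
            rows = list(csv.DictReader(f))
        print(len(rows))  # 18
```

Checking for the file's existence before opening it in append mode is what keeps the header from being repeated at the start of each new run.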

You can then load the CSV file and plot the model's loss, learning rate, metrics and so on.

import os

import matplotlib.pyplot as plt
import pandas as pd

# model_directory is the same directory the callback writes to
history_dataframe = pd.read_csv(os.path.join(model_directory, 'model_history.csv'), sep=',')

# One row per epoch across all fit calls (e.g. 10 + 3 + 5 = 18 rows),
# so derive the x-axis from the data instead of hard-coding the epoch count
epochs = range(1, len(history_dataframe) + 1)

# Plot training & validation loss values
plt.style.use("ggplot")
plt.plot(epochs, history_dataframe['loss'])
plt.plot(epochs, history_dataframe['val_loss'], linestyle='--')
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
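If pandas is not available, the same CSV can be read back with the standard csv module into the same {metric: [value per epoch]} shape as History.history, so the number of epochs comes from the rows themselves rather than a hard-coded constant. A minimal sketch, assuming every cell is numeric; load_history_csv is a hypothetical helper name.

```python
import csv
from typing import Dict, List


def load_history_csv(path: str) -> Dict[str, List[float]]:
    """Read a per-epoch history CSV back into a {metric: [value per epoch]} dict."""
    history: Dict[str, List[float]] = {}
    with open(path, newline='') as f:
        for row in csv.DictReader(f):
            for metric, value in row.items():
                history.setdefault(metric, []).append(float(value))
    return history


# Usage, assuming the path matches the callback's output:
# h = load_history_csv(os.path.join(model_directory, 'model_history.csv'))
# plt.plot(range(1, len(h['loss']) + 1), h['loss'])
```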