Sra*_*mar 5 epoch neural-network keras tensorflow
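I am training my model with epochs=10. I then retrain it with epochs=3, and again with epochs=5, so the model is trained three times, with epochs = 10, 3, and 5. I want to combine the history of all three runs. For example, let h1 be the history returned by model.fit with epochs=10, h2 the history from model.fit with epochs=3, and h3 the history from model.fit with epochs=5.
Now I want a variable h that holds h1 + h2 + h3, with all of the history appended into a single variable so that I can plot some graphs.
The code is: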
start_time = time.time()
model.fit(x=X_train, y=y_train, batch_size=32, epochs=10, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])
end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")
start_time = time.time()
model.fit(x=X_train, y=y_train, batch_size=32, epochs=3, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])
end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")
start_time = time.time()
model.fit(x=X_train, y=y_train, batch_size=32, epochs=5, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])
end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")
小智 8
You can do this by subclassing tf.keras.callbacks.Callback and passing an instance of that class as a callback to model.fit.
import csv
import os

import tensorflow.keras.backend as K
from tensorflow import keras

# Directory to save model history after every epoch (note the trailing slash)
model_directory = './xyz/'
history_path = os.path.join(model_directory, 'model_history.csv')


class StoreModelHistory(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = dict(logs or {})
        # Always record the optimizer's current learning rate
        logs['lr'] = K.get_value(self.model.optimizer.lr)
        # Write the header row only when the file does not exist yet
        write_header = not os.path.exists(history_path)
        os.makedirs(model_directory, exist_ok=True)
        with open(history_path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=list(logs.keys()))
            if write_header:
                writer.writeheader()
            writer.writerow(logs)


model.fit(..., callbacks=[StoreModelHistory()])
You can then load the CSV file and plot the model's loss, learning rate, metrics, and so on.
import pandas as pd
import matplotlib.pyplot as plt

history_dataframe = pd.read_csv(model_directory+'model_history.csv', sep=',')
# Total number of epochs logged so far, across every model.fit call
EPOCH = len(history_dataframe)

# Plot training & validation loss values
plt.style.use("ggplot")
plt.plot(range(1, EPOCH+1),
         history_dataframe['loss'])
plt.plot(range(1, EPOCH+1),
         history_dataframe['val_loss'],
         linestyle='--')
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
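If you would rather keep everything in memory, as the question's h = h1 + h2 + h3 suggests, another option is to concatenate the per-run dictionaries yourself. model.fit returns a History object whose .history attribute is a dict mapping each metric name to a list with one value per epoch. The following is only a minimal sketch: merge_histories is a hypothetical helper, not part of Keras, and the numbers are made-up stand-ins for h1.history, h2.history, and h3.history.

```python
from collections import defaultdict


def merge_histories(*histories):
    """Concatenate Keras-style history dicts, in the order given.

    Hypothetical helper (not a Keras API). Each argument is expected to be
    a dict of metric name -> list of per-epoch values, like History.history.
    """
    merged = defaultdict(list)
    for history in histories:
        for metric, values in history.items():
            merged[metric].extend(values)
    return dict(merged)


# Made-up stand-in values. In real use these would be h1.history,
# h2.history and h3.history, where h1 = model.fit(..., epochs=10), etc.
h1 = {"loss": [0.9, 0.7], "val_loss": [1.0, 0.8]}
h2 = {"loss": [0.6], "val_loss": [0.75]}
h3 = {"loss": [0.5, 0.4], "val_loss": [0.7, 0.65]}

h = merge_histories(h1, h2, h3)
total_epochs = len(h["loss"])  # one entry per epoch across all three runs
```

The merged dict can be plotted the same way as the CSV columns, for example plt.plot(range(1, total_epochs + 1), h["loss"]). Unlike the CSV approach, this history is lost when the Python session ends.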