每个纪元后的自定义回调以记录某些信息

Bas*_*asj 6 python machine-learning keras

我知道如何在每个纪元后保存模型

savemodel = ModelCheckpoint(filepath='models/model_{epoch:02d}-{loss:.2f}.h5')
model.fit(X, Y, batch_size=4, epochs=32, verbose=1, callbacks=[savemodel])
Run Code Online (Sandbox Code Playgroud)

如何使用自定义回调函数来记录某些信息:

def write_metrics(): 
    with open('log.txt', 'a') as f:  # append to the log file
        f.write('{epoch:02d}: loss = {loss:.1f}')

model.fit(X, Y, batch_size=4, epochs=32, verbose=1, callbacks=[savemodel, write_metrics])
Run Code Online (Sandbox Code Playgroud)

使用此代码将无法工作,因为{loss}{epoch}未在 中定义f.write('{epoch:02d}: loss = {loss:.1f}')

Bas*_*asj 6

这是通过子类化的解决方案Callback

from keras.callbacks import Callback

class MyLogger(Callback):
    def on_epoch_end(self, epoch, logs=None):
        with open('log.txt', 'a+') as f:
            f.write('%02d %.3f\n' % (epoch, logs['loss']))
Run Code Online (Sandbox Code Playgroud)

然后

mylogger = MyLogger()
model.fit(X, Y, batch_size=32, epochs=32, verbose=1, callbacks=[mylogger])
Run Code Online (Sandbox Code Playgroud)

甚至

model.fit(X, Y, batch_size=32, epochs=32, verbose=1, callbacks=[MyLogger()])
Run Code Online (Sandbox Code Playgroud)

  • 我在某些地方看到“logs={}”,即字典和“logs=None”,如您提供的示例中所示,您能否解释一下“logs”参数以及在您的示例中如何获得损失使用“日志['损失']”? (2认同)