How can I use a custom metric from a callback with EarlyStopping or ModelCheckpoint?

1 callback keras

I want to use a custom metric computed in one callback from inside another callback such as EarlyStopping or ModelCheckpoint. So I need to somehow save/store/log this custom metric so that the other callbacks can access it. How can I do that?


I have:

```python
import numpy as np
import sklearn.metrics as sklm
import keras


class Metrics(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        # Per-epoch history plus the latest scalar values
        self.precision = []
        self.f1scores = []
        self.prc = 0
        self.f1s = 0

    def on_epoch_end(self, epoch, logs={}):
        # Predict on the validation pairs and round the output
        score = np.asarray(self.model.predict([self.validation_data[0], self.validation_data[1]]))
        predict = np.round(score)
        targ = self.validation_data[2]

        # Invert the rounded output: small distances count as the positive class
        predict = (predict < 0.5).astype(float)

        self.prc = sklm.precision_score(targ, predict)
        self.f1s = sklm.f1_score(targ, predict)
        self.precision.append(self.prc)
        self.f1scores.append(self.f1s)

        print("- val_f1: %f - val_precision: %f" % (self.f1s, self.prc))
```

Now,

```python
metrics = Metrics()

es = EarlyStopping(monitor=metrics.prc, mode='max', verbose=1, patience=3,
                   min_delta=0.01, restore_best_weights=True)

model.compile(loss=contrastive_loss, optimizer=adam)
model.fit([train_sen1, train_sen2], train_labels,
          batch_size=512,
          epochs=20, callbacks=[metrics, es],
          validation_data=([dev_sen1, dev_sen2], dev_labels))
```

does not work, presumably because EarlyStopping has no knowledge of the custom precision metric?


Does anyone know how the "logs" argument of the callbacks works? Can I store my metric there?


小智 7

To see what is really going on here, you have to look at the source code of the EarlyStopping and ModelCheckpoint classes on GitHub. You can find it there.


The problem in your code is that you never update the "logs" dictionary inside your "on_epoch_end" function. That dictionary is exactly where the EarlyStopping and ModelCheckpoint classes look for whatever you define as the "monitor".
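To see the mechanism in isolation, here is a minimal, self-contained sketch (the class names, toy model and dummy metric value are made up purely for illustration): whatever one callback writes into "logs" during "on_epoch_end" is visible to the callbacks that run after it in the same epoch.

```python
import numpy as np
import keras
from keras.layers import Dense

class WritesMetric(keras.callbacks.Callback):
    # Stores a (dummy) custom value in the shared "logs" dict each epoch.
    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            logs["prc"] = 0.1 * epoch  # stand-in for a real precision score

class ReadsMetric(keras.callbacks.Callback):
    # Runs after WritesMetric and can read the value, just as EarlyStopping would.
    def on_epoch_end(self, epoch, logs=None):
        print("epoch %d: prc in logs = %s" % (epoch, (logs or {}).get("prc")))

# Toy data and model, only there to drive the callbacks.
x = np.random.rand(64, 4)
y = np.random.randint(0, 2, size=(64, 1))
model = keras.models.Sequential([Dense(1, activation="sigmoid", input_shape=(4,))])
model.compile(loss="binary_crossentropy", optimizer="adam")

# Order matters: the writer has to come before the reader.
model.fit(x, y, epochs=3, verbose=0, callbacks=[WritesMetric(), ReadsMetric()])
```

Running this should print, for each epoch, the "prc" value that the first callback stored.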


So in your case, if you want to monitor the precision score, your code should look like this:

```python
import numpy as np
import sklearn.metrics as sklm
import keras


class Metrics(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        # Per-epoch history plus the latest scalar values
        self.precision = []
        self.f1scores = []
        self.prc = 0
        self.f1s = 0

    def on_epoch_end(self, epoch, logs={}):
        # Predict on the validation pairs and round the output
        score = np.asarray(self.model.predict([self.validation_data[0], self.validation_data[1]]))
        predict = np.round(score)
        targ = self.validation_data[2]

        # Invert the rounded output: small distances count as the positive class
        predict = (predict < 0.5).astype(float)

        self.prc = sklm.precision_score(targ, predict)
        self.f1s = sklm.f1_score(targ, predict)
        self.precision.append(self.prc)
        self.f1scores.append(self.f1s)

        # Here is where I update the logs dictionary:
        logs["prc"] = self.prc
        logs["f1s"] = self.f1s

        print("- val_f1: %f - val_precision: %f" % (self.f1s, self.prc))
```

You can then refer to these custom metrics in ModelCheckpoint and EarlyStopping. Just make sure to pass the callbacks to fit (or fit_generator) in the right order: the Metrics callback has to come first, otherwise the logs will not have been updated yet by the time EarlyStopping runs.

```python
metrics = Metrics()

es = EarlyStopping(monitor="prc", mode='max', verbose=1, patience=3,
                   min_delta=0.01, restore_best_weights=True)

model.compile(loss=contrastive_loss, optimizer=adam)
model.fit([train_sen1, train_sen2], train_labels,
          batch_size=512,
          epochs=20, callbacks=[metrics, es],
          validation_data=([dev_sen1, dev_sen2], dev_labels))
```
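
If you also want to save checkpoints based on the custom metric, the same monitor name works with ModelCheckpoint. A sketch, continuing the snippet above (the file name "best_by_prc.h5" is just a placeholder):

```python
from keras.callbacks import ModelCheckpoint

# Save the model whenever the custom "prc" value improves.
mc = ModelCheckpoint("best_by_prc.h5", monitor="prc", mode="max",
                     save_best_only=True, verbose=1)

# Add it after the Metrics callback, which must run first so that "prc"
# is already in the logs: callbacks=[metrics, es, mc]
```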