我想在另一个回调(例如 EarlyStopping 或 ModelCheckpoint)中使用来自回调的自定义指标。但我需要以某种方式保存/存储/记录这个自定义指标,以便其他回调可以访问它?
\n\n我有:
\n\nclass Metrics(keras.callbacks.Callback):\n def on_train_begin(self, logs={}):\n\n self.precision = []\n self.f1s = []\n self.prc=0\n self.f1s=0\n\n def on_epoch_end(self, epoch, logs={}):\n score = np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]]))\n predict = np.round(np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]])))\n targ = self.validation_data[2]\n\n predict = (predict < 0.5).astype(np.float)\n\n\n self.prc=sklm.precision_score(targ, predict)\n self.f1s=sklm.f1_score(targ, predict)\n self.precision.append(prc)\n self.f1s.append(f1s)\n\n print("\xe2\x80\x94 val_f1: %f \xe2\x80\x94 val_precision: %f" %(self.f1s, self.prc))\n return\nRun Code Online (Sandbox Code Playgroud)\n\n现在,
\n\nmetrics = Metrics()\n\nes = EarlyStopping(monitor=metrics.prc, mode=\'max\', verbose=1, patience=3,min_delta=0.01,restore_best_weights=True)\n\nmodel.compile(loss=contrastive_loss, optimizer=adam)\nmodel.fit([train_sen1, train_sen2], train_labels,\n batch_size=512,\n epochs=20,callbacks=[metrics,es],\n validation_data=([dev_sen1, dev_sen2], dev_labels))\nRun Code Online (Sandbox Code Playgroud)\n\n不起作用,因为 Earlystopping 不知道自定义精度指标?
\n\n有人知道这个回调日志语句吗?我可以在那里保存我的指标吗?
\n小智 7
要了解这里到底发生了什么,您必须去 github 上检查 EarlyStopping 和 ModelCheckpoint 类的源代码。你可以在这里找到它。
\n\n您代码中的问题是您没有更新“on_epoch_end”函数中的“logs”字典。EarlyStopping 和 ModelCheckpoint 类就是在该字典中查找您定义为“监视器”的内容。
\n\n因此,在您的情况下,如果您要使用精度分数作为监视器,您的代码应如下所示:
\n\nclass Metrics(keras.callbacks.Callback):\n def on_train_begin(self, logs={}):\n\n self.precision = []\n self.f1scores = []\n self.prc=0\n self.f1s=0\n\n def on_epoch_end(self, epoch, logs={}):\n score = np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]]))\n predict = np.round(np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]])))\n targ = self.validation_data[2]\n\n predict = (predict < 0.5).astype(np.float)\n\n\n self.prc=sklm.precision_score(targ, predict)\n self.f1s=sklm.f1_score(targ, predict)\n self.precision.append(prc)\n self.f1scores.append(f1s)\n\n #Here is where I update the logs dictionary:\n logs["prc"]=self.prc\n logs["f1s"]=self.f1s\n\n print("\xe2\x80\x94 val_f1: %f \xe2\x80\x94 val_precision: %f" %(self.f1s, self.prc))\nRun Code Online (Sandbox Code Playgroud)\n\n然后,您可以在 CheckpointModel 和 EarlyStopping 中调用这些自定义指标。但请确保在 fit_generator 中以正确的顺序放置这些回调:指标应放在第一位,否则当您运行 EarlyStopping 时,您的日志将不会更新。
\n\nmetrics = Metrics()\n\nes = EarlyStopping(monitor="prc", mode=\'max\', verbose=1, patience=3,min_delta=0.01,restore_best_weights=True)\n\nmodel.compile(loss=contrastive_loss, optimizer=adam)\nmodel.fit([train_sen1, train_sen2], train_labels,\n batch_size=512,\n epochs=20,callbacks=[metrics,es],\n validation_data=([dev_sen1, dev_sen2], dev_labels))\nRun Code Online (Sandbox Code Playgroud)\n
| 归档时间: |
|
| 查看次数: |
2339 次 |
| 最近记录: |