在自定义回调中访问验证数据

Enr*_*ndo 18 python metrics keras

我正在安装train_generator,并通过自定义回调,我想在validation_generator上计算自定义指标.如何访问PARAMS validation_stepsvalidation_data 一个自定义的回调中?它不在self.params,也找不到它self.model.这就是我想做的事情.任何不同的方法都会受到欢迎.

model.fit_generator(generator=train_generator,
                    steps_per_epoch=steps_per_epoch,
                    epochs=epochs,
                    validation_data=validation_generator,
                    validation_steps=validation_steps,
                    callbacks=[CustomMetrics()])


class CustomMetrics(keras.callbacks.Callback):

    def on_epoch_end(self, batch, logs={}):        
        for i in validation_steps:
             # features, labels = next(validation_data)
             # compute custom metric: f(features, labels) 
        return
Run Code Online (Sandbox Code Playgroud)

keras:2.1.1

更新

我设法将验证数据传递给自定义回调的构造函数.但是,这会导致令人讨厌的"内核似乎已经死亡.它会自动重启".信息.我怀疑这是否是正确的方法.有什么建议吗?

class CustomMetrics(keras.callbacks.Callback):

    def __init__(self, validation_generator, validation_steps):
        self.validation_generator = validation_generator
        self.validation_steps = validation_steps


    def on_epoch_end(self, batch, logs={}):

        self.scores = {
            'recall_score': [],
            'precision_score': [],
            'f1_score': []
        }

        for batch_index in range(self.validation_steps):
            features, y_true = next(self.validation_generator)            
            y_pred = np.asarray(self.model.predict(features))
            y_pred = y_pred.round().astype(int) 
            self.scores['recall_score'].append(recall_score(y_true[:,0], y_pred[:,0]))
            self.scores['precision_score'].append(precision_score(y_true[:,0], y_pred[:,0]))
            self.scores['f1_score'].append(f1_score(y_true[:,0], y_pred[:,0]))
        return

metrics = CustomMetrics(validation_generator, validation_steps)

model.fit_generator(generator=train_generator,
                    steps_per_epoch=steps_per_epoch,
                    epochs=epochs,
                    validation_data=validation_generator,
                    validation_steps=validation_steps,
                    shuffle=True,
                    callbacks=[metrics],
                    verbose=1)
Run Code Online (Sandbox Code Playgroud)

小智 8

您可以直接迭代 self.validation_data 以在每个 epoch 结束时聚合所有验证数据。如果要计算整个验证数据集的准确率、召回率和 F1:

# Validation metrics callback: validation precision, recall and F1
# Some of the code was adapted from https://medium.com/@thongonary/how-to-compute-f1-score-for-each-epoch-in-keras-a1acd17715a2
class Metrics(callbacks.Callback):

    def on_train_begin(self, logs={}):
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs):
        # 5.4.1 For each validation batch
        for batch_index in range(0, len(self.validation_data)):
            # 5.4.1.1 Get the batch target values
            temp_targ = self.validation_data[batch_index][1]
            # 5.4.1.2 Get the batch prediction values
            temp_predict = (np.asarray(self.model.predict(
                                self.validation_data[batch_index][0]))).round()
            # 5.4.1.3 Append them to the corresponding output objects
            if(batch_index == 0):
                val_targ = temp_targ
                val_predict = temp_predict
            else:
                val_targ = np.vstack((val_targ, temp_targ))
                val_predict = np.vstack((val_predict, temp_predict))

        val_f1 = round(f1_score(val_targ, val_predict), 4)
        val_recall = round(recall_score(val_targ, val_predict), 4)
        val_precis = round(precision_score(val_targ, val_predict), 4)

        self.val_f1s.append(val_f1)
        self.val_recalls.append(val_recall)
        self.val_precisions.append(val_precis)

        # Add custom metrics to the logs, so that we can use them with
        # EarlyStop and csvLogger callbacks
        logs["val_f1"] = val_f1
        logs["val_recall"] = val_recall
        logs["val_precis"] = val_precis

        print("— val_f1: {} — val_precis: {} — val_recall {}".format(
                 val_f1, val_precis, val_recall))
        return

valid_metrics = Metrics()
Run Code Online (Sandbox Code Playgroud)

然后您可以将 valid_metrics 添加到回调参数中:

your_model.fit_generator(..., callbacks = [valid_metrics])
Run Code Online (Sandbox Code Playgroud)

请务必将其放在回调的开头,以防您希望其他回调使用这些措施。

  • 有没有办法使用验证数据的预测结果,而不是再次计算它们? (4认同)
  • 在 `def on_epoch_end(self, batch, messages)` 中访问 self.validation 的先决条件是什么?我总是遇到“AttributeError:'Metrics'对象没有属性'validation_data'” (3认同)

W. *_*am 1

我正在锁定同一问题的解决方案,然后我在此处接受的答案中找到了您的解决方案和另一个解决方案。如果第二个解决方案有效,我认为这会比在“纪元结束时”再次迭代彻底的所有验证更好

这个想法是将目标和预测占位符保存在变量中,并通过“批处理结束时”的自定义回调更新变量