如何计算Keras的精度和召回率

Jim*_* Du 40 python precision precision-recall keras

我正在用Keras 2.02(带有Tensorflow后端)构建一个多类分类器,我不知道如何计算Keras中的精度和召回率.请帮我.

Yas*_*nov 38

Python包keras-metrics可能对此有用(我是包的作者).

import keras
import keras_metrics

model = models.Sequential()
model.add(keras.layers.Dense(1, activation="sigmoid", input_dim=2))
model.add(keras.layers.Dense(1, activation="softmax"))

model.compile(optimizer="sgd",
              loss="binary_crossentropy",
              metrics=[keras_metrics.precision(), keras_metrics.recall()])
Run Code Online (Sandbox Code Playgroud)

  • 提及您是此套餐的作者可能是一个好主意,这既是为了充分披露您的联盟,也是为了让您的答案更具可信度. (3认同)
  • 它在Keras 2.2.2和keras-metrics 0.0.5中给出了精度和查全率= 0.000。我使用了与上面相同的代码。但是,当我使用@Christian发布的解决方案时,它返回值。 (2认同)

小智 30

从Keras 2.0开始,精确和召回从主分支中删除.你必须自己实现它们.请按照本指南创建自定义指标:此处.

精度和回忆方程可以在这里找到

或者从keras重用代码,它已被删除之前,在这里.

已删除指标,因为它们是批处理的,因此值可能正确也可能不正确.

  • 他们为什么被删除? (16认同)
  • 即使链接中断,请确保您的答案有用. (3认同)

vog*_*gdb 5

我的回答是基于Keras GH问题评论。它为一个热编码的分类任务计算验证精度和在每个时期的召回率。另外,请查看此SO答案以了解如何使用keras.backend功能。

import keras as keras
import numpy as np
from keras.optimizers import SGD
from sklearn.metrics import precision_score, recall_score

model = keras.models.Sequential()
# ...
sgd = SGD(lr=0.001, momentum=0.9)
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])


class Metrics(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self._data = []

    def on_epoch_end(self, batch, logs={}):
        X_val, y_val = self.validation_data[0], self.validation_data[1]
        y_predict = np.asarray(model.predict(X_val))

        y_val = np.argmax(y_val, axis=1)
        y_predict = np.argmax(y_predict, axis=1)

        self._data.append({
            'val_recall': recall_score(y_val, y_predict),
            'val_precision': precision_score(y_val, y_predict),
        })
        return

    def get_data(self):
        return self._data


metrics = Metrics()
history = model.fit(X_train, y_train, epochs=100, validation_data=(X_val, y_val), callbacks=[metrics])
metrics.get_data()
Run Code Online (Sandbox Code Playgroud)