如何绘制精度和多类分类器的召回率?

Joh*_*all 1 python matplotlib roc scikit-learn precision-recall

我正在使用scikit学习,我想绘制精度和召回曲线。我正在使用的分类器是RandomForestClassifier。scikit学习文档中的所有资源均使用二进制分类。另外,我可以为多类绘制ROC曲线吗?

另外,我只找到了支持向量机的多标签,它有一个decision_functionRandomForest没有

sen*_*nce 5

从scikit-learn文档中:

精确调用曲线通常用于二进制分类中,以研究分类器的输出。为了将精确度调用曲线和平均精确度扩展到多类或多标签分类,必须对输出进行二值化。每个标签可以绘制一条曲线,但也可以通过将标签指示符矩阵的每个元素视为二进制预测(微平均)来绘制精确召回曲线。

ROC曲线通常用于二进制分类中以研究分类器的输出。为了将ROC曲线和ROC区域扩展到多类或多标签分类,有必要对输出进行二值化。可以为每个标签绘制一条ROC曲线,但也可以通过将标签指示符矩阵的每个元素视为二进制预测(微平均)来绘制ROC曲线。

因此,您应该对输出进行二值化处理,并考虑每个类的精确率调用和roc曲线。而且,您将predict_proba用来获取类概率。

我将代码分为三部分:

  1. 常规设置,学习和预测
  2. 精确召回曲线
  3. ROC曲线

1.一般设置,学习和预测

from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt
#%matplotlib inline

mnist = fetch_mldata("MNIST original")
n_classes = len(set(mnist.target))

Y = label_binarize(mnist.target, classes=[*range(n_classes)])

X_train, X_test, y_train, y_test = train_test_split(mnist.data,
                                                    Y,
                                                    random_state = 42)

clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
                             max_depth=3,
                             random_state=0))
clf.fit(X_train, y_train)

y_score = clf.predict_proba(X_test)
Run Code Online (Sandbox Code Playgroud)

2.精确调用曲线

# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i]))
    plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))

plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明

3. ROC曲线

# roc curve
fpr = dict()
tpr = dict()

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
                                  y_score[:, i]))
    plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))

plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明

  • 为什么我使用 OneVsRestClassifier?RandomForest 不是已经支持多类了吗? (3认同)