使用没有分类器的 scikit-learn 绘制混淆矩阵

Iri*_*ina 16 python confusion-matrix scikit-learn

我有一个用sklearn.metrics.confusion_matrix.

现在,我想用 来绘制它sklearn.metrics.plot_confusion_matrix,但第一个参数是经过训练的分类器,如文档中所指定。问题是我没有分类器;结果是通过手动计算获得的。

是否仍然可以通过 scikit-learn 在一行中绘制混淆矩阵,还是必须使用 matplotlib 自己编写代码?

Viv*_*mar 31

您可以plot_confusion_matrix直接导入这一事实表明您安装了最新版本的 scikit-learn (0.22)。所以你可以看看源代码,plot_confusion_matrix()看看它是如何使用estimator.

来自此处最新来源,估算器用于:

  1. 使用计算混淆矩阵 confusion_matrix
  2. 获取标签(y 的唯一值,对应于混淆矩阵中的 0,1,2..)

所以如果你已经有了这两件事,你只需要以下部分:

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=display_labels)


# NOTE: Fill all variables here with default values of the plot_confusion_matrix
disp = disp.plot(include_values=include_values,
                 cmap=cmap, ax=ax, xticks_rotation=xticks_rotation)

plt.show()
Run Code Online (Sandbox Code Playgroud)

请查看注释中的注释。

对于旧版本,您可以在此处查看matplotlib 部分的编码方式