绘制多标签分类 Python 的混淆矩阵

use*_*437 8 python decision-tree confusion-matrix scikit-learn multilabel-classification

我正在寻找可以帮助我绘制混淆矩阵的人。我在大学的学期论文中需要这个。但是我在编程方面的经验很少。

在图片中,您可以看到分类报告以及 myy_testX_testmy 案例的结构dtree_predictions

如果有人可以帮助我,我会很高兴,因为我尝试了很多事情,但我只是没有得到解决方案,只有错误消息。

X_train, X_test, y_train, y_test = train_test_split(X, Y_profile, test_size = 0.3, random_state = 30)

dtree_model = DecisionTreeClassifier().fit(X_train,y_train)
dtree_predictions = dtree_model.predict(X_test)

print(metrics.classification_report(dtree_predictions, y_test))
              precision    recall  f1-score   support

       0       1.00      1.00      1.00       222
       1       1.00      1.00      1.00       211
       2       1.00      1.00      1.00       229
       3       0.96      0.97      0.96       348
       4       0.89      0.85      0.87        93
       5       0.86      0.86      0.86       105
       6       0.94      0.93      0.94       116
       7       1.00      1.00      1.00       364
       8       0.99      0.97      0.98       139
       9       0.98      0.99      0.99       159
      10       0.97      0.96      0.97       189
      11       0.92      0.92      0.92       124
      12       0.92      0.92      0.92       119
      13       0.95      0.96      0.95       230
      14       0.98      0.96      0.97       452
      15       0.91      0.96      0.93       210

micro avg       0.96      0.96      0.96      3310
macro avg       0.95      0.95      0.95      3310
weighted avg    0.97      0.96      0.96      3310
samples avg     0.96      0.96      0.96      3310
Run Code Online (Sandbox Code Playgroud)

接下来我打印多标签混淆矩阵的度量

from sklearn.metrics import multilabel_confusion_matrix
multilabel_confusion_matrix(y_test, dtree_predictions)

array([[[440,   0],
    [  0, 222]],

   [[451,   0],
    [  0, 211]],

   [[433,   0],
    [  0, 229]],

   [[299,  10],
    [ 15, 338]],

   [[559,  14],
    [ 10,  79]],

   [[542,  15],
    [ 15,  90]],

   [[539,   8],
    [  7, 108]],

   [[297,   0],
    [  1, 364]],

   [[522,   4],
    [  1, 135]],

   [[500,   1],
    [  3, 158]],

   [[468,   8],
    [  5, 181]],

   [[528,  10],
    [ 10, 114]],

   [[534,   9],
    [  9, 110]],

   [[420,   9],
    [ 12, 221]],

   [[201,  19],
    [  9, 433]],

   [[433,   9],
    [ 19, 201]]])
Run Code Online (Sandbox Code Playgroud)

和的结构y_testdtree_predictons

print(dtree_predictions)
print(dtree_predictions.shape)

[[0. 0. 1. ... 0. 1. 0.]
[1. 0. 0. ... 0. 1. 0.]
[0. 0. 1. ... 0. 1. 0.]
 ...
[1. 0. 0. ... 0. 0. 1.]
[0. 1. 0. ... 1. 0. 1.]
[0. 1. 0. ... 1. 0. 1.]]
(662, 16)

print(y_test)

      Cooler close to failure  Cooler reduced effiency  Cooler full    effiency  \
1985                      0.0                      0.0                   1.0   
322                       1.0                      0.0                   0.0   
2017                      0.0                      0.0                   1.0   
1759                      0.0                      0.0                   1.0   
1602                      0.0                      0.0                     1.0   
...                       ...                      ...                      ...   
128                       1.0                      0.0                   0.0   
321                       1.0                      0.0                   0.0   
53                        1.0                      0.0                   0.0   
859                       0.0                      1.0                     0.0   
835                       0.0                      1.0                       0.0   

  valve optimal  valve small lag  valve severe lag  \
1985            0.0              0.0               0.0   
322             0.0              1.0               0.0   
2017            1.0              0.0               0.0   
1759            0.0              0.0               0.0   
1602            1.0              0.0               0.0   
...             ...              ...               ...   
128             1.0              0.0               0.0   
321             0.0              1.0               0.0   
53              1.0              0.0               0.0   
859             1.0              0.0               0.0   
835             1.0              0.0               0.0   

  valve close to failure  pump no leakage  pump weak leakage  \
1985                     1.0              0.0                1.0   
322                      0.0              1.0                0.0   
2017                     0.0              0.0                1.0   
1759                     1.0              1.0                0.0   
1602                     0.0              1.0                0.0   
...                      ...              ...                ...   
128                      0.0              1.0                0.0   
321                      0.0              1.0                0.0   
53                       0.0              1.0                0.0   
859                      0.0              1.0                0.0   
835                      0.0              1.0                0.0   

  pump severe leakage  accu optimal pressure  \
1985                  0.0                    0.0   
322                   0.0                    1.0   
2017                  0.0                    0.0   
1759                  0.0                    1.0   
1602                  0.0                    0.0   
...                   ...                    ...   
128                   0.0                    1.0   
321                   0.0                    1.0   
53                    0.0                    1.0   
859                   0.0                    0.0   
835                   0.0                    0.0   

  accu slightly reduced pressure  accu severly reduced pressure  \
1985                             0.0                            1.0   
322                              0.0                            0.0   
2017                             0.0                            1.0   
1759                             0.0                            0.0   
1602                             0.0                            0.0   
...                              ...                            ...   
128                              0.0                            0.0   
321                              0.0                            0.0   
53                               0.0                            0.0   
859                              0.0                            0.0   
835                              0.0                            0.0   

  accu close to failure  stable flag stable  stable flag not stable  
1985                    0.0                 1.0                     0.0  
322                     0.0                 1.0                     0.0  
2017                    0.0                 1.0                     0.0  
1759                    0.0                 1.0                     0.0  
1602                    1.0                 0.0                     1.0  
...                     ...                 ...                     ...  
128                     0.0                 0.0                     1.0  
321                     0.0                 1.0                     0.0  
53                      0.0                 0.0                     1.0  
859                     1.0                 0.0                     1.0  
835                     1.0                 0.0                     1.0  

[662 rows x 16 columns]
Run Code Online (Sandbox Code Playgroud)

col*_*ldy 10

通常,混淆矩阵通过热图进行可视化。在github 中还创建了一个函数来漂亮地打印混淆矩阵。受此启发,我适应了多标签场景,其中每个具有二元预测 (Y, N) 的类都被添加到矩阵中并通过热图进行可视化。

这是从发布的代码中获取一些输出的示例:

为每个标签获得的混淆矩阵变成了一个二元分类问题。

多标签混淆矩阵将 TN 置于 (0,0) 并将 TP 置于 (1,1) 位置,感谢@Kenneth Witham 指出。
import numpy as np

vis_arr = np.asarray([[[440,   0],
    [  0, 222]],

   [[451,   0],
    [  0, 211]],

   [[433,   0],
    [  0, 229]],

   [[299,  10],
    [ 15, 338]],

   [[559,  14],
    [ 10,  79]],

   [[542,  15],
    [ 15,  90]],

   [[539,   8],
    [  7, 108]],

   [[297,   0],
    [  1, 364]],

   [[522,   4],
    [  1, 135]],

   [[500,   1],
    [  3, 158]],

   [[468,   8],
    [  5, 181]],

   [[528,  10],
    [ 10, 114]],

   [[534,   9],
    [  9, 110]],

   [[420,   9],
    [ 12, 221]],

   [[201,  19],
    [  9, 433]],

   [[433,   9],
    [ 19, 201]]])
Run Code Online (Sandbox Code Playgroud)

手动创建的类标签 c0 到 c15。

labels = ["".join("c" + str(i)) for i in range(0, 16)]
Run Code Online (Sandbox Code Playgroud)

混淆矩阵自适应的多标签可视化

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def print_confusion_matrix(confusion_matrix, axes, class_label, class_names, fontsize=14):

    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names,
    )

    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d", cbar=False, ax=axes)
    except ValueError:
        raise ValueError("Confusion matrix values must be integers.")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    axes.set_ylabel('True label')
    axes.set_xlabel('Predicted label')
    axes.set_title("Confusion Matrix for the class - " + class_label)
Run Code Online (Sandbox Code Playgroud)

更新多标签分类可视化

将基本混淆矩阵扩展到以标题为每个类的子图网格的绘图。这里的 [Y, N] 是定义的类标签,可以扩展。

fig, ax = plt.subplots(4, 4, figsize=(12, 7))
    
    for axes, cfs_matrix, label in zip(ax.flatten(), vis_arr, labels):
        print_confusion_matrix(cfs_matrix, axes, label, ["N", "Y"])
    
    fig.tight_layout()
    plt.show()
Run Code Online (Sandbox Code Playgroud)

注意:此图是基于关于混淆矩阵的维基文章构建的

输出:

在此处输入图片说明

  • 如果它有效,您可以接受它作为答案。您提供的链接的作用有所不同。多标签转换为多类,然后可视化。此外,可视化是从 16 个类别中每个类别获得的概率。为了以这种方式进行可视化,您可以使用 sklearn `predict_proba()` 并将其用于我在答案中指出的函数。 (2认同)

Ven*_*lam 6

你可以使用ConfusionMatrixDisplay的选项sklearn.metrics

例子:

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_multilabel_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_multilabel_classification(n_samples=1000,
                                      n_classes=15, random_state=42)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42)

tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

y_pred = tree.predict(X_test)

f, axes = plt.subplots(3, 5, figsize=(25, 15))
axes = axes.ravel()
for i in range(15):
    disp = ConfusionMatrixDisplay(confusion_matrix(y_test[:, i],
                                                   y_pred[:, i]),
                                  display_labels=[0, i])
    disp.plot(ax=axes[i], values_format='.4g')
    disp.ax_.set_title(f'class {i}')
    if i<10:
        disp.ax_.set_xlabel('')
    if i%5!=0:
        disp.ax_.set_ylabel('')
    disp.im_.colorbar.remove()

plt.subplots_adjust(wspace=0.10, hspace=0.1)
f.colorbar(disp.im_, ax=axes)
plt.show()

Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明