如何绘制混淆矩阵?

min*_*nks 51 python matplotlib matrix scikit-learn text-classification

我正在使用scikit-learn将文本文档(22000)分类为100个类.我使用scikit-learn的混淆矩阵方法来计算混淆矩阵.

model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm=metrics.confusion_matrix(test_labels,pred)
print(cm)
plt.imshow(cm, cmap='binary')
Run Code Online (Sandbox Code Playgroud)

这就是我的混淆矩阵的样子:

[[3962  325    0 ...,    0    0    0]
 [ 250 2765    0 ...,    0    0    0]
 [   2    8   17 ...,    0    0    0]
 ..., 
 [   1    6    0 ...,    5    0    0]
 [   1    1    0 ...,    0    0    0]
 [   9    0    0 ...,    0    0    9]]
Run Code Online (Sandbox Code Playgroud)

但是,我没有收到明确或清晰的情节.有一个更好的方法吗?

bni*_*aul 101

在此输入图像描述

你可以用plt.matshow()而不是plt.imshow()或者你可以使用seaborn模块heatmap(见文档)来绘制混淆矩阵

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
array = [[33,2,0,0,0,0,0,0,0,1,3], 
        [3,31,0,0,0,0,0,0,0,0,0], 
        [0,4,41,0,0,0,0,0,0,0,1], 
        [0,1,0,30,0,6,0,0,0,0,1], 
        [0,0,0,0,38,10,0,0,0,0,0], 
        [0,0,0,3,1,39,0,0,0,0,4], 
        [0,2,2,0,4,1,31,0,0,0,2],
        [0,1,0,0,0,0,0,36,0,2,0], 
        [0,0,0,0,0,0,1,5,37,5,1], 
        [3,0,0,0,0,0,0,0,0,39,0], 
        [0,0,0,0,0,0,0,0,0,0,38]]
df_cm = pd.DataFrame(array, index = [i for i in "ABCDEFGHIJK"],
                  columns = [i for i in "ABCDEFGHIJK"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)
Run Code Online (Sandbox Code Playgroud)


小智 41

@bninopaul的回答并不完全适合初学者

这是你可以"复制和运行"的代码

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

array = [[13,1,1,0,2,0],
     [3,9,6,0,1,0],
     [0,0,16,2,0,0],
     [0,0,0,13,0,0],
     [0,0,0,0,15,0],
     [0,0,1,0,0,15]]        
df_cm = pd.DataFrame(array, range(6),
                  range(6))
#plt.figure(figsize = (10,7))
sn.set(font_scale=1.4)#for label size
sn.heatmap(df_cm, annot=True,annot_kws={"size": 16})# font size

plt.show()
Run Code Online (Sandbox Code Playgroud)

结果

  • 我不明白为什么这个答案更适合“初学者”?...它与bninopaul的基本相同。 (6认同)
  • conf 矩阵是 *初学者大小* @DavidSkarbrevik ;) (4认同)
  • 只是添加,对于自定义的`x`和`y`标签,用这样的东西替换`df_cm`行:`df_cm = pd.DataFrame(array,index = ["stage 1","stage 2","stage 3 ","stagte 4"],columns = ["stage 1","stage 2","stage 3","stagte 4"]) (2认同)

Wag*_*ano 17

如果你想在混淆矩阵中有更多数据,包括" totals column "和" totals line ",以及每个单元格中的百分比(%),就像matlab默认(见下图)

在此输入图像描述

包括热图和其他选项......

你应该玩上面的模块,在github中共享; )

https://github.com/wcipriano/pretty-print-confusion-matrix


这个模块可以轻松完成你的任务并产生上面的输出,有很多参数来定制你的CM: 在此输入图像描述