如何知道 scikit-learn 混淆矩阵的标签顺序并更改它

yyn*_*uin 7 machine-learning confusion-matrix scikit-learn deep-learning

存在 27 个类别的多分类问题。

y_predict=[0 0 0 20 26 21 21 26 ....]

y_true=[1 10 10 20 26 21 18 26 ...]  
Run Code Online (Sandbox Code Playgroud)

名为“answer_vocabulary”的列表存储了每个索引对应的 27 个单词。answer_vocabulary=[0 1 10 11 2 3 农商东住北.....]

cm = 混淆矩阵(y_true=y_true, y_pred=y_predict)

我对混淆矩阵的顺序感到困惑。它是按索引升序排列的吗?如果我想用标签序列=[0 1 2 3 10 11农业商业生活东北...]重新排序混淆矩阵,我该如何实现呢?

这是我尝试绘制混淆矩阵的函数。

def plot_confusion_matrix(cm, classes,
                        normalize=False,
                        title='Confusion matrix',
                        cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
Run Code Online (Sandbox Code Playgroud)

Iñi*_*lez 9

sklearn 的混淆矩阵不存储有关如何创建矩阵的信息(类排序和标准化):这意味着您必须在创建混淆矩阵后立即使用它,否则信息将丢失。

默认情况下,sklearn.metrics.confusion_matrix(y_true,y_pred)按照类在 y_true 中出现的顺序创建矩阵。

如果将此数据传递给sklearn.metrix.confusion_matrix

+--------+--------+
| y_true | y_pred |
+--------+--------+
| A      | B      |
| C      | C      |
| D      | B      |
| B      | A      |
+--------+--------+
Run Code Online (Sandbox Code Playgroud)

Scikit-leart 将创建这个混淆矩阵(省略零):

+-----------+---+---+---+---+
| true\pred | A | C | D | B | 
+-----------+---+---+---+---+
| A         |   |   |   | 1 |
| C         |   | 1 |   |   |
| D         |   |   |   | 1 |
| B         | 1 |   |   |   |
+-----------+---+---+---+---+
Run Code Online (Sandbox Code Playgroud)

它会将这个 numpy 矩阵返回给您:

+---+---+---+---+
| 0 | 0 | 0 | 1 |
| 0 | 0 | 1 | 0 |
| 0 | 0 | 0 | 1 |
| 1 | 0 | 0 | 0 |
+---+---+---+---+
Run Code Online (Sandbox Code Playgroud)

如果您想选择类或对它们重新排序,您可以将 'labels' 参数传递给confusion_matrix().

对于重新排序:

labels = ['D','C','B','A']
mat = confusion_matrix(true_y,pred_y, labels=labels)

Run Code Online (Sandbox Code Playgroud)

或者,如果您只想关注一些标签(如果您有很多标签,则很有用):

labels = ['A','D']
mat = confusion_matrix(true_y,pred_y, labels=labels)
Run Code Online (Sandbox Code Playgroud)

另外,看看sklearn.metrics.plot_confusion_matrix。它非常适合小班(<100 人)。

如果您有 >100 个类,则需要使用白色来绘制矩阵。