小编yyn*_*uin的帖子

如何知道 scikit-learn 混淆矩阵的标签顺序并更改它

存在 27 个类别的多分类问题。

y_predict=[0 0 0 20 26 21 21 26 ....]

y_true=[1 10 10 20 26 21 18 26 ...]  
Run Code Online (Sandbox Code Playgroud)

名为“answer_vocabulary”的列表存储了每个索引对应的 27 个单词。answer_vocabulary=[0 1 10 11 2 3 农商东住北.....]

cm = 混淆矩阵(y_true=y_true, y_pred=y_predict)

我对混淆矩阵的顺序感到困惑。它是按索引升序排列的吗?如果我想用标签序列=[0 1 2 3 10 11农业商业生活东北...]重新排序混淆矩阵,我该如何实现呢?

这是我尝试绘制混淆矩阵的函数。

def plot_confusion_matrix(cm, classes,
                        normalize=False,
                        title='Confusion matrix',
                        cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize: …
Run Code Online (Sandbox Code Playgroud)

machine-learning confusion-matrix scikit-learn deep-learning

7
推荐指数
1
解决办法
8626
查看次数