yyn*_*uin 7 machine-learning confusion-matrix scikit-learn deep-learning
存在 27 个类别的多分类问题。
y_predict=[0 0 0 20 26 21 21 26 ....]
y_true=[1 10 10 20 26 21 18 26 ...]
Run Code Online (Sandbox Code Playgroud)
名为“answer_vocabulary”的列表存储了每个索引对应的 27 个单词。answer_vocabulary=[0 1 10 11 2 3 农商东住北.....]
cm = 混淆矩阵(y_true=y_true, y_pred=y_predict)
我对混淆矩阵的顺序感到困惑。它是按索引升序排列的吗?如果我想用标签序列=[0 1 2 3 10 11农业商业生活东北...]重新排序混淆矩阵,我该如何实现呢?
这是我尝试绘制混淆矩阵的函数。
def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
print(cm)
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, cm[i, j],
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
Run Code Online (Sandbox Code Playgroud)
sklearn 的混淆矩阵不存储有关如何创建矩阵的信息(类排序和标准化):这意味着您必须在创建混淆矩阵后立即使用它,否则信息将丢失。
默认情况下,sklearn.metrics.confusion_matrix(y_true,y_pred)按照类在 y_true 中出现的顺序创建矩阵。
如果将此数据传递给sklearn.metrix.confusion_matrix:
+--------+--------+
| y_true | y_pred |
+--------+--------+
| A | B |
| C | C |
| D | B |
| B | A |
+--------+--------+
Run Code Online (Sandbox Code Playgroud)
Scikit-leart 将创建这个混淆矩阵(省略零):
+-----------+---+---+---+---+
| true\pred | A | C | D | B |
+-----------+---+---+---+---+
| A | | | | 1 |
| C | | 1 | | |
| D | | | | 1 |
| B | 1 | | | |
+-----------+---+---+---+---+
Run Code Online (Sandbox Code Playgroud)
它会将这个 numpy 矩阵返回给您:
+---+---+---+---+
| 0 | 0 | 0 | 1 |
| 0 | 0 | 1 | 0 |
| 0 | 0 | 0 | 1 |
| 1 | 0 | 0 | 0 |
+---+---+---+---+
Run Code Online (Sandbox Code Playgroud)
如果您想选择类或对它们重新排序,您可以将 'labels' 参数传递给confusion_matrix().
对于重新排序:
labels = ['D','C','B','A']
mat = confusion_matrix(true_y,pred_y, labels=labels)
Run Code Online (Sandbox Code Playgroud)
或者,如果您只想关注一些标签(如果您有很多标签,则很有用):
labels = ['A','D']
mat = confusion_matrix(true_y,pred_y, labels=labels)
Run Code Online (Sandbox Code Playgroud)
另外,看看sklearn.metrics.plot_confusion_matrix。它非常适合小班(<100 人)。
如果您有 >100 个类,则需要使用白色来绘制矩阵。