Vu *_*Anh 10 python matplotlib
今天我试图从我的分类模型中绘制混淆矩阵.
在一些页面搜索后,我发现matshow从中pyplot可以帮助我.
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title(title)
fig.colorbar(cax)
if labels:
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
Run Code Online (Sandbox Code Playgroud)
如果我的标签很少,它的效果很好
y_true = ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'a', 'c', 'd', 'b', 'a', 'b', 'a']
y_pred = ['a', 'b', 'c', 'd', 'a', 'b', 'b', 'a', 'c', 'a', 'a', 'a', 'a', 'a']
labels = list(set(y_true))
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)
Run Code Online (Sandbox Code Playgroud)
但如果我有很多标签,一些标签不能正确显示
y_true = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
y_pred = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
labels = list(set(y_true))
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)
Run Code Online (Sandbox Code Playgroud)
我的问题是如何在matshow情节中显示所有标签?我尝试了类似的东西,fontdict但它仍然没有用
小智 6
您可以使用xticks方法来指定标签。您的函数将如下所示(修改上述答案中的函数):
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
fig = plt.figure()
ax = fig.add_subplot(111)
# I also added cmap=cmap here, to make use of the
# colormap you specify in the function call
cax = ax.matshow(cm,cmap=cmap)
plt.title(title)
fig.colorbar(cax)
if labels:
plt.xticks(range(len(labels)), labels)
plt.yticks(range(len(labels)), labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('confusionmatrix.png')
Run Code Online (Sandbox Code Playgroud)
您可以使用该matplotlib.ticker模块控制刻度的频率。
在这种情况下,您想在的每一个倍数上设置一个勾号1,因此我们可以使用MultipleLocator
在致电之前添加这两行plt.show():
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
Run Code Online (Sandbox Code Playgroud)
它将为您y_true和中的每个字母产生一个刻度和标签y_pred。
我还更改了您的matshow调用,以利用您在函数调用中指定的颜色图:
cax = ax.matshow(cm,cmap=cmap)
Run Code Online (Sandbox Code Playgroud)
为了完整起见,您的整个功能将如下所示:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import matplotlib.ticker as ticker
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
fig = plt.figure()
ax = fig.add_subplot(111)
# I also added cmap=cmap here, to make use of the
# colormap you specify in the function call
cax = ax.matshow(cm,cmap=cmap)
plt.title(title)
fig.colorbar(cax)
if labels:
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('confusionmatrix.png')
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
6032 次 |
| 最近记录: |