Sha*_*kib 8 python machine-learning confusion-matrix scikit-learn
当我尝试制作 CNN 模型的混淆矩阵时,我遇到了一些问题。当我运行代码时,它返回一些错误,例如:
print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))
Traceback (most recent call last):
File "<ipython-input-102-82d46efe536a>", line 1, in <module>
print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))
File "G:\anaconda_installation_file\lib\site-packages\sklearn\metrics\classification.py", line 1543, in classification_report
"parameter".format(len(labels), len(target_names))
ValueError: Number of classes, 4, does not match size of target_names, 6. Try specifying the labels parameter
Run Code Online (Sandbox Code Playgroud)
我已经搜索过解决这个问题的方法,但仍然没有得到完美的解决方案。我在这个领域完全陌生,有人可以帮助我吗?谢谢。
from sklearn.metrics import classification_report,confusion_matrix
import itertools
Y_pred = model.predict(X_test)
print(Y_pred)
y_pred = np.argmax(Y_pred, axis=1)
print(y_pred)
target_names = ['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)']
print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))
Run Code Online (Sandbox Code Playgroud)
问题是你有 6 个标签名称:'class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)'
但是你的confusion_matrix中只有4个类,当你打印: print(y_pred): 你会得到带有数字的东西0,1,2,3,或者当你print(y_test)从 得到数字时0,1,2,3,它应该有助于删除:
print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))
从你的代码来看,不知何故你没有 6 个预测/测试类。
这里还有一个如何绘制混淆矩阵的示例:How can Iplot a混淆矩阵?
| 归档时间: |
|
| 查看次数: |
22127 次 |
| 最近记录: |