显示分类错误的实例

Oph*_*lia 2 python scikit-learn

我正在使用Scikit-learn构建SVM分类器...并且在运行分类器时..我想通过检查分类错误的实例并试图找出分类错误的原因来提高分类器的准确性...所以有没有办法显示分类错误的实例?

Rol*_*Max 5

有没有办法显示分类错误的实例?

是的,您需要在这里和那里做一些索引。下面是一个示例,但是技术细节将取决于分类器的输入和输出方式。

最简单的情况是输出是单个值时,因此您可以轻松比较实例是否已正确分类。例如,让我们收集一些数据并训练一个二进制分类器:

>>> from sklearn import cross_validation, datasets, svm
>>> X, y = datasets.make_classification()
>>> X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y)
>>> clf = svm.LinearSVC().fit(X_train, y_train)
>>> y_pred = clf.predict(X_test)
Run Code Online (Sandbox Code Playgroud)

您可以直接比较y_testy_pred因为输出是单个值。如果您正在训练一个多类模型,那么您将无法进行直接比较,而应该逐个类地进行比较。

>>> misclassified_samples = X_test[y_test != y_pred]
Run Code Online (Sandbox Code Playgroud)

如果需要,您也可以将布尔掩码转换为索引。

>>> import numpy as np
>>> np.flatnonzero(y_test != y_pred)
array([ 0, 20, 22])
Run Code Online (Sandbox Code Playgroud)