sklearn - 具有多个分数的交叉验证

blu*_*fer 28 python numpy scikit-learn

我想计算不同分类器的交叉验证测试的召回率,精度f度量. scikit-learn附带了cross_val_score,但遗憾的是这种方法不会返回多个值.

我可以通过调用cross_val_score 三次 来计算这样的度量,但效率不高.有没有更好的解决方案?

到现在为止我写了这个函数:

from sklearn import metrics

def mean_scores(X, y, clf, skf):

    cm = np.zeros(len(np.unique(y)) ** 2)
    for i, (train, test) in enumerate(skf):
        clf.fit(X[train], y[train])
        y_pred = clf.predict(X[test])
        cm += metrics.confusion_matrix(y[test], y_pred).flatten()

    return compute_measures(*cm / skf.n_folds)

def compute_measures(tp, fp, fn, tn):
     """Computes effectiveness measures given a confusion matrix."""
     specificity = tn / (tn + fp)
     sensitivity = tp / (tp + fn)
     fmeasure = 2 * (specificity * sensitivity) / (specificity + sensitivity)
     return sensitivity, specificity, fmeasure
Run Code Online (Sandbox Code Playgroud)

它基本上总结了混淆矩阵值,一旦你有假阳性,假阴性等你就可以轻松计算回忆,精度等......但我仍然不喜欢这个解决方案:)

Tom*_*DLT 12

现在在scikit-learn中:cross_validate是一个可以在多个指标上评估模型的新功能.此功能也可用于GridSearchCVRandomizedSearchCV(doc).它最近在master中合并 ,将在v0.19中提供.

scikit-learn doc:

cross_validate功能cross_val_score有两种不同之处:1.它允许指定多个评估指标.2.除了测试分数之外,它还返回包含训练分数,适合时间和分数时间的字典.

典型的用例是:

from sklearn.svm import SVC
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
iris = load_iris()
scoring = ['precision', 'recall', 'f1']
clf = SVC(kernel='linear', C=1, random_state=0)
scores = cross_validate(clf, iris.data, iris.target == 1, cv=5,
                        scoring=scoring, return_train_score=False)
Run Code Online (Sandbox Code Playgroud)

另请参见此示例.


eic*_*erg 6

您提供的解决方案完全符合您的具体功能cross_val_score,完全适合您的情况.这似乎是正确的方法.

cross_val_score接受论证n_jobs=,使评估可并行化.如果这是你需要的东西,你应该考虑使用并行循环替换你的for循环sklearn.externals.joblib.Parallel.

更一般地说,正在讨论关于scikit学习问题跟踪器中多个分数的问题.可在此处找到代表性线程.因此,虽然看起来未来版本的scikit-learn将允许得分者的多个输出,但截至目前,这是不可能的.

一个hacky(免责声明!)解决这个问题的方法是cross_validation.py通过删除条件检查您的分数是否为数字来稍微改变代码.但是,这个建议非常依赖于版本,因此我将其用于版本0.14.

1)在IPython中,输入from sklearn import cross_validation,然后输入cross_validation??.记下显示的文件名并在编辑器中打开它(您可能需要root权限).

2)您将找到此代码,我已经在其中标记了相关的行(1066).它说

    if not isinstance(score, numbers.Number):
        raise ValueError("scoring must return a number, got %s (%s)"
                         " instead." % (str(score), type(score)))
Run Code Online (Sandbox Code Playgroud)

需要删除这些行.为了跟踪曾经有过的东西(如果你想要改回),请用以下内容替换它

    if not isinstance(score, numbers.Number):
        pass
        # raise ValueError("scoring must return a number, got %s (%s)"
        #                 " instead." % (str(score), type(score)))
Run Code Online (Sandbox Code Playgroud)

如果您的得分者返回的内容不会在cross_val_score其他地方窒息,这应该可以解决您的问题.如果是这种情况,请告诉我.