Col*_*lis 15 python classification threshold roc scikit-learn
我使用scikit-learn训练了ExtraTreesClassifier(gini索引),它非常适合我的需求.准确性不是很好,但使用10倍交叉验证,AUC为0.95.我想在我的工作中使用这个分类器.我对ML很新,所以如果我问你一些概念错误的话,请原谅我.
我绘制了一些ROC曲线,通过它,我似乎有一个特定的阈值,我的分类器开始表现良好.我想在拟合的分类器上设置这个值,所以每次我调用预测时,分类器都会使用该阈值,我可以相信FP和TP的速率.
我也来到这篇文章(scikit .predict()默认阈值),其中声明阈值不是分类器的通用概念.但由于ExtraTreesClassifier的方法是predict_proba,并且ROC曲线也与thresdholds定义有关,所以在我看来我应该可以指定它.
我没有找到任何参数,也没有找到任何类/接口来实现它.如何使用scikit-learn为训练有素的ExtraTreesClassifier(或任何其他人)设置阈值?
非常感谢,科利斯
fam*_*gar 15
这就是我所做的:
model = SomeSklearnModel()
model.fit(X_train, y_train)
predict = model.predict(X_test)
predict_probabilities = model.predict_proba(X_test)
fpr, tpr, _ = roc_curve(y_test, predict_probabilities)
Run Code Online (Sandbox Code Playgroud)
然而,我很生气,预测会选择相当于0.4%真阳性的阈值(误报为零).ROC曲线显示了一个阈值,我更喜欢我的问题,其中真阳性约为20%(假阳性约为4%).然后我扫描predict_probabilities以找出哪个概率值对应于我最喜欢的ROC点.在我的情况下,这个概率是0.21.然后我创建自己的预测数组:
predict_mine = np.where(rf_predict_probabilities > 0.21, 1, 0)
Run Code Online (Sandbox Code Playgroud)
然后你去:
confusion_matrix(y_test, predict_mine)
Run Code Online (Sandbox Code Playgroud)
返回我想要的东西:
array([[6927, 309],
[ 621, 121]])
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
10390 次 |
| 最近记录: |