sklearn.ensemble.AdaBoostClassifier不能将SVM作为base_estimator加入?

all*_*ang 14 python machine-learning scikit-learn ensemble-learning

我正在做一个文本分类任务.现在我想ensemble.AdaBoostClassifierLinearSVCas base_estimator.但是,当我尝试运行代码时

clf = AdaBoostClassifier(svm.LinearSVC(),n_estimators=50, learning_rate=1.0,    algorithm='SAMME.R')
clf.fit(X, y)
Run Code Online (Sandbox Code Playgroud)

发生错误. TypeError: AdaBoostClassifier with algorithm='SAMME.R' requires that the weak learner supports the calculation of class probabilities with a predict_proba method

第一个问题是无法svm.LinearSVC()计算类概率?如何计算概率?

然后我更改参数algorithm并再次运行代码.

clf = AdaBoostClassifier(svm.LinearSVC(),n_estimators=50, learning_rate=1.0, algorithm='SAMME')
clf.fit(X, y)
Run Code Online (Sandbox Code Playgroud)

这一次TypeError: fit() got an unexpected keyword argument 'sample_weight'发生了.正如在AdaBoostClassifier中所说,Sample weights. If None, the sample weights are initialized to 1 / n_samples.即使我分配了一个整数n_samples,也会发生错误.

第二个问题是什么n_samples意思?如何解决这个问题呢?

希望有人能帮助我.

然而,根据@jme的评论,经过尝试

clf = AdaBoostClassifier(svm.SVC(kernel='linear',probability=True),n_estimators=10,  learning_rate=1.0, algorithm='SAMME.R')
clf.fit(X, y)
Run Code Online (Sandbox Code Playgroud)

程序无法获得结果,服务器上使用的内存保持不变.

第三个问题是我如何AdaBoostClassifier使用SVCbase_estimator 进行工作?

小智 13

正确的答案将取决于您正在寻找什么.LinearSVC无法预测类概率(AdaBoostClassifier使用的默认算法所需)并且不支持sample_weight.

您应该知道支持向量机不会在名义上预测类概率.它们是使用Platt缩放(或多类情况下Platt缩放的扩展)计算的,这是一种已知问题的技术.如果您需要较少的"人工"类概率,SVM可能不是最佳选择.

话虽如此,我相信给出你的问题最令人满意的答案是格雷厄姆给出的.那是,

from sklearn.svm import SVC
from sklearn.ensemble import AdaBoostClassifier

clf = AdaBoostClassifier(SVC(probability=True, kernel='linear'), ...)
Run Code Online (Sandbox Code Playgroud)

你还有其他选择.您可以将SGDClassifier与铰链损失函数一起使用,并将AdaBoostClassifier设置为使用SAMME算法(不需要predict_proba函数,但需要支持sample_weight):

from sklearn.linear_model import SGDClassifier

clf = AdaBoostClassifier(SGDClassifier(loss='hinge'), algorithm='SAMME', ...)
Run Code Online (Sandbox Code Playgroud)

如果您想使用为AdaBoostClassifier提供的默认算法,最好的答案可能是使用对类概率具有本机支持的分类器,如Logistic回归.您可以使用scikit.linear_model.LogisticRegression或使用具有日志丢失功能的SGDClassifier来执行此操作,如Kris提供的代码中所使用的那样.

希望有所帮助,如果您对Platt缩放的内容感到好奇,请查看John Platt的原始论文.