sklearn - 预测每个班级的概率

bmc*_*bmc 1 machine-learning probability multilabel-classification predictive

到目前为止,我已经为另一个帖子sklearn文档提供了资源

所以一般来说我想生成以下示例:

X = np.matrix([[1,2],[2,3],[3,4],[4,5]])
y = np.array(['A', 'B', 'B', 'C', 'D'])
Xt = np.matrix([[11,22],[22,33],[33,44],[44,55]])
model = model.fit(X, y)
pred = model.predict(Xt)
Run Code Online (Sandbox Code Playgroud)

但是对于输出,我希望每个观察看到3列作为输出pred:

 A  |  B  |  C
.5  | .2  | .3
.25 | .25 | .5
...
Run Code Online (Sandbox Code Playgroud)

每个班级出现在我的预测中的概率不同.

我相信最好的方法是Multilabel classification从我上面提供的第二个链接.另外,我认为跳入下面列出的一个multi-label或多个multi-output模型可能是一个好主意:

Support multilabel:

    sklearn.tree.DecisionTreeClassifier
    sklearn.tree.ExtraTreeClassifier
    sklearn.ensemble.ExtraTreesClassifier
    sklearn.neighbors.KNeighborsClassifier
    sklearn.neural_network.MLPClassifier
    sklearn.neighbors.RadiusNeighborsClassifier
    sklearn.ensemble.RandomForestClassifier
    sklearn.linear_model.RidgeClassifierCV

Support multiclass-multioutput:

    sklearn.tree.DecisionTreeClassifier
    sklearn.tree.ExtraTreeClassifier
    sklearn.ensemble.ExtraTreesClassifier
    sklearn.neighbors.KNeighborsClassifier
    sklearn.neighbors.RadiusNeighborsClassifier
    sklearn.ensemble.RandomForestClassifier
Run Code Online (Sandbox Code Playgroud)

但是,我正在寻找能够以正确的方式做到这一点的人更有信心和经验的人.所有反馈都表示赞赏.

-bmc

Mak*_*ich 8

根据我的理解,您希望获得多类分类器的每个潜在类的概率.

在Scikit-Learn中,它可以通过泛型函数predict_proba来完成.它是针对scikit-learn中的大多数分类器实现的.你基本上打电话:

clf.predict_proba(X)
Run Code Online (Sandbox Code Playgroud)

clf受过训练的分类器在哪里.作为输出,您将获得每个输入值的每个类的十进制数组.

需要注意的是 - 并非所有分类器都能自然地评估类概率.例如,SVM不这样做.您仍然可以获得类概率,但是在构造这样的分类器时,您需要指示它执行概率估计.对于SVM,它看起来像:

SVC(Probability=True)
Run Code Online (Sandbox Code Playgroud)

适合它之后,您将能够predict_proba像以前一样使用.

我需要警告你,如果分类器不能自然地评估概率,那意味着将使用相当广泛的计算方法来评估概率,这可能会显着增加训练时间.所以我建议你使用自然评估类概率的分类器(具有softmax输出的神经网络,逻辑回归,梯度增强等)

  • 您如何知道给哪个标签的概率呢?例如,y_pred = clf.predict_proba(X_test_tfidf [:len(df_test)])`产生此输出`array([[0.29354825,0.08547672,0.62097503],[0.75855171,0.13965677,0.10179152],[0.39376194,0.50768248,0.09855559], ...,[0.78636186,0.0804752,0.13316294],[0.32583947,0.06651614,0.60764439],[0.36811811,0.53192139,0.0999605]])`我怎么知道第一,第二和第三个因素代表什么? (2认同)
  • @bmc使用clf.classes_将给您正确的顺序 (2认同)