如何在DecisionTreeClassifier中设置类别权重以进行多类别设置

MAC*_*MAC 3 python machine-learning decision-tree scikit-learn

我用来sklearn.tree.DecisionTreeClassifier训练三类分类问题。

3类中的记录数如下:

A: 122038
B: 43626
C: 6678
Run Code Online (Sandbox Code Playgroud)

当我训练分类器模型时,它无法学习类 - C。虽然效率为65-70%,但完全忽略了C类。

然后我开始了解class_weight参数,但我不确定如何在多类设置中使用它。

这是我的代码:(我使用过balanced,但它的准确性更差)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = tree.DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=1,class_weight='balanced')
clf = clf.fit(X_train,y_train)
y_pred = clf.predict(X_test)
Run Code Online (Sandbox Code Playgroud)

如何使用与类别分布成比例的权重。

其次,有没有更好的方法来解决这个不平衡类问题以提高准确性?

Jer*_*are 5

您还可以将值字典传递给 class_weight 参数以设置您自己的权重。例如,要将 A 级重量减半,您可以这样做:

class_weight={
    'A': 0.5,
    'B': 1.0,
    'C': 1.0
}
Run Code Online (Sandbox Code Playgroud)

通过执行 class_weight='balanced' 它会自动设置与类别频率成反比的权重。

更多信息可以在 class_weight 参数下的文档中找到: https ://scikit-learn.org/stable/modules/ generated/sklearn.tree.DecisionTreeClassifier.html

通常预计平衡类别会降低准确性。这就是为什么准确性通常被认为是不平衡数据集的一个糟糕的指标。

您可以尝试 sklearn 包含的平衡准确性指标作为开始,但是还有许多其他潜在的指标可以尝试,这取决于您的最终目标是什么。

https://scikit-learn.org/stable/modules/model_evaluation.html

如果您不熟悉“混淆矩阵”及其相关值(例如精度和召回率),那么我会从那里开始您的研究。

https://en.wikipedia.org/wiki/Precision_and_recall

https://en.wikipedia.org/wiki/Confusion_matrix

https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html