MAC*_*MAC 3 python machine-learning decision-tree scikit-learn
我用来sklearn.tree.DecisionTreeClassifier
训练三类分类问题。
3类中的记录数如下:
A: 122038
B: 43626
C: 6678
Run Code Online (Sandbox Code Playgroud)
当我训练分类器模型时,它无法学习类 - C
。虽然效率为65-70%,但完全忽略了C类。
然后我开始了解class_weight
参数,但我不确定如何在多类设置中使用它。
这是我的代码:(我使用过balanced
,但它的准确性更差)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = tree.DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=1,class_weight='balanced')
clf = clf.fit(X_train,y_train)
y_pred = clf.predict(X_test)
Run Code Online (Sandbox Code Playgroud)
如何使用与类别分布成比例的权重。
其次,有没有更好的方法来解决这个不平衡类问题以提高准确性?
您还可以将值字典传递给 class_weight 参数以设置您自己的权重。例如,要将 A 级重量减半,您可以这样做:
class_weight={
'A': 0.5,
'B': 1.0,
'C': 1.0
}
Run Code Online (Sandbox Code Playgroud)
通过执行 class_weight='balanced' 它会自动设置与类别频率成反比的权重。
更多信息可以在 class_weight 参数下的文档中找到: https ://scikit-learn.org/stable/modules/ generated/sklearn.tree.DecisionTreeClassifier.html
通常预计平衡类别会降低准确性。这就是为什么准确性通常被认为是不平衡数据集的一个糟糕的指标。
您可以尝试 sklearn 包含的平衡准确性指标作为开始,但是还有许多其他潜在的指标可以尝试,这取决于您的最终目标是什么。
https://scikit-learn.org/stable/modules/model_evaluation.html
如果您不熟悉“混淆矩阵”及其相关值(例如精度和召回率),那么我会从那里开始您的研究。
https://en.wikipedia.org/wiki/Precision_and_recall
https://en.wikipedia.org/wiki/Confusion_matrix
https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
归档时间: |
|
查看次数: |
8429 次 |
最近记录: |