使用scikit-learn时,如何找到树分裂的属性?

Mat*_*ien 42 python machine-learning decision-tree scikit-learn

我一直在探索scikit-learn,使用熵和基尼分裂标准制定决策树,并探索差异.

我的问题是,我如何"打开引擎盖"并确切地找出树在每个级别上分裂的属性及其相关的信息值,以便我可以看到这两个标准在哪里做出不同的选择?

到目前为止,我已经探索了文档中概述的9种方法.它们似乎不允许访问此信息.但是这些信息肯定是可以访问的吗?我正在设想一个包含节点和增益条目的列表或字典.

如果我错过了一些完全明显的东西,感谢您的帮助和道歉.

lej*_*lot 34

直接来自文档(http://scikit-learn.org/0.12/modules/tree.html):

from io import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
Run Code Online (Sandbox Code Playgroud)

StringIOPython3不再支持io模块,而是导入模块.

tree_决策树对象中还有属性,允许直接访问整个结构.

你可以简单地阅读它

clf.tree_.children_left #array of left children
clf.tree_.children_right #array of right children
clf.tree_.feature #array of nodes splitting feature
clf.tree_.threshold #array of nodes splitting points
clf.tree_.value #array of nodes values
Run Code Online (Sandbox Code Playgroud)

有关更多详细信息,请查看导出方法源代码

通常,您可以使用该inspect模块

from inspect import getmembers
print( getmembers( clf.tree_ ) )
Run Code Online (Sandbox Code Playgroud)

获取所有对象的元素

来自sklearn docs的决策树可视化

  • "左"是否总是"真实"值,右边是"假"? (7认同)

Dan*_*son 11

如果您只想快速查看树中发生的情况,请尝试:

zip(X.columns[clf.tree_.feature], clf.tree_.threshold, clf.tree_.children_left, clf.tree_.children_right)
Run Code Online (Sandbox Code Playgroud)

其中X是自变量的数据框,clf是决策树对象.请注意,clf.tree_.children_left并且clf.tree_.children_right一起包含该分割作了(这些每一个将对应于在graphviz的可视化中的箭头)的顺序.


小智 8

Scikit learnexport_text在 0.21 版本(2019 年 5 月)中引入了一种美味的新方法,可以从树中查看所有规则。文档在这里

拟合模型后,您只需要两行代码。首先,导入export_text

from sklearn.tree.export import export_text
Run Code Online (Sandbox Code Playgroud)

其次,创建一个包含您的规则的对象。为了使规则看起来更具可读性,请使用feature_names参数并传递您的功能名称列表。例如,如果您的模型被调用model并且您的特征在名为 的数据框中命名X_train,则您可以创建一个名为 的对象tree_rules

tree_rules = export_text(model, feature_names=list(X_train))
Run Code Online (Sandbox Code Playgroud)

然后只需打印或保存tree_rules。您的输出将如下所示:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
Run Code Online (Sandbox Code Playgroud)

  • 请注意:`FutureWarning:sklearn.tree.export 模块在 0.22 版本中已弃用,并将在 0.24 版本中删除。相应的类/函数应该从 sklearn.tree 导入。任何无法从 sklearn.tree 导入的内容现在都是私有 API 的一部分。 (3认同)