如何探索使用scikit学习构建的决策树

ele*_*ora 11 python machine-learning decision-tree scikit-learn

我正在使用构建决策树

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)
Run Code Online (Sandbox Code Playgroud)

一切正常.但是,我如何探索决策树?

例如,如何找到X_train中的哪些条目出现在特定的叶子中?

Pab*_*rre 13

您需要使用预测方法.

训练树后,您可以提供X值以预测其输出.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data) 
Run Code Online (Sandbox Code Playgroud)

输出:

>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
Run Code Online (Sandbox Code Playgroud)

要获得有关树结构的详细信息,我们可以使用 tree_.__getstate__()

树形结构翻译成"ASCII艺术"图片

              0  
        _____________
        1           2
               ______________
               3            12
            _______      _______
            4     7      13   16
           ___   ______        _____
           5 6   8    9        14 15
                      _____
                      10 11
Run Code Online (Sandbox Code Playgroud)

树结构作为一个数组.

In [38]: tree.tree_.__getstate__()['nodes']
Out[38]: 
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
       (-1, -1, -2, -2.0, 0.0, 50, 50.0),
       (3, 12, 3, 1.75, 0.5, 100, 100.0),
       (4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
       (5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
       (-1, -1, -2, -2.0, 0.0, 47, 47.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
       (-1, -1, -2, -2.0, 0.0, 3, 3.0),
       (10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
       (14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (-1, -1, -2, -2.0, 0.0, 43, 43.0)], 
      dtype=[('left_child', '<i8'), ('right_child', '<i8'), 
             ('feature', '<i8'), ('threshold', '<f8'), 
             ('impurity', '<f8'), ('n_node_samples', '<i8'), 
             ('weighted_n_node_samples', '<f8')])
Run Code Online (Sandbox Code Playgroud)

哪里:

  • 第一个节点[0]是根节点.
  • 内部节点有left_child和right_child引用具有正值且大于当前节点的节点.
  • 叶子具有左右子节点的-​​1值.
  • 节点1,5,6,8,10,11,14,15,16是叶子.
  • 节点结构使用深度优先搜索算法构建.
  • 功能字段告诉我们在节点中使用了哪个iris.data功能来确定此样本的路径.
  • 阈值告诉我们用于根据特征评估方向的值.
  • 叶子上的杂质达到0 ...因为所有样品一旦到达叶子就属于同一类.
  • n_node_samples告诉我们有多少样本到达每个叶子.

使用这些信息,我们可以通过遵循脚本上的分类规则和阈值,轻松地跟踪每个样本X到最终着陆的叶子.另外,n_node_samples允许我们执行单元测试,确保每个节点获得正确数量的样本.然后使用tree.predict的输出,我们可以将每个叶子映射到关联的类.


zax*_*liu 5

注意:这不是答案,只是提示可能的解决方案.

我最近在项目中遇到了类似的问题.我的目标是为某些特定样本提取相应的决策链.我认为你的问题是我的一个子集,因为你只需要记录决策链的最后一步.

到目前为止,似乎唯一可行的解​​决方案是在Python中编写自定义predict方法以跟踪整个过程中的决策.原因是predictscikit-learn提供的方法无法开箱即用(据我所知).更糟糕的是,它是C实现的包装器,很难定制.

定制对我的问题很好,因为我正在处理不平衡的数据集,而我关心的样本(正面的样本)很少见.所以我可以先使用sklearn过滤掉它们predict,然后使用我的自定义来获取决策链.

但是,如果您有大型数据集,这可能对您不起作用.因为如果您解析树并在Python中进行预测,它将在Python速度上运行缓慢并且不会(轻松)扩展.您可能不得不回退自定义C实现.


Cha*_*ley 3

下面的代码应该生成前十个功能的图:

import numpy as np
import matplotlib.pyplot as plt

importances = clf.feature_importances_
std = np.std(clf.feature_importances_,axis=0)
indices = np.argsort(importances)[::-1]

# Print the feature ranking
print("Feature ranking:")

for f in range(10):
    print("%d. feature %d (%f)" % (f + 1, indices[f], importances[indices[f]]))

# Plot the feature importances of the forest
plt.figure()
plt.title("Feature importances")
plt.bar(range(10), importances[indices],
       color="r", yerr=std[indices], align="center")
plt.xticks(range(10), indices)
plt.xlim([-1, 10])
plt.show()
Run Code Online (Sandbox Code Playgroud)

取自此处并稍加修改以适应DecisionTreeClassifier

这并不能完全帮助您探索这棵树,但它确实告诉您有关这棵树的信息。