如何从scikit-learn解释决策树

Stu*_*ack 23 python numpy scipy decision-tree scikit-learn

从scikit-learn了解决策树的结果我有两个问题.例如,这是我的决策树之一:

在此输入图像描述 我的问题是我如何使用树?

第一个问题是:如果样本满足条件,那么它将进入LEFT分支(如果存在),否则它将变为正确.在我的情况下,如果样本有X [7]> 63521.3984.然后样品将进入绿色框.正确?

第二个问题是:当样本到达叶节点时,我怎么知道它属于哪个类别?在这个例子中,我有三个类别要分类.在红色框中,分别满足91,212和113个样本的条件.但是我该如何确定类别呢?我知道有一个函数 clf.predict(样本)来告诉类别.我可以从图表中做到吗??? 非常感谢.

Bre*_*arn 26

value每个框中的行按顺序告诉您该节点上有多少样本属于每个类别.这就是为什么在每个方框中,数字value加起来显示的数字sample.例如,在您的红色框中,91 + 212 + 113 = 416.因此,这意味着如果到达此节点,类别1中有91个数据点,类别2中有212个数据点,类别3中有113个数据点.

如果您要预测在决策树中到达该叶子的新数据点的结果,则可以预测类别2,因为这是该节点上样本的最常见类别.


Myo*_*age 5

第一个问题: 是的,您的逻辑是正确的。左节点为True,右节点为False。这可能是违反直觉的;true可以等同于较小的样本。

第二个问题: 通过使用pydotplus将树可视化为图形,可以最好地解决此问题。tree.export_graphviz()的'class_names'属性将为每个节点的多数类添加一个类声明。代码在iPython笔记本中执行。

from sklearn.datasets import load_iris  
from sklearn import tree  
iris = load_iris()  
clf2 = tree.DecisionTreeClassifier()  
clf2 = clf2.fit(iris.data, iris.target)  

with open("iris.dot", 'w') as f:  
    f = tree.export_graphviz(clf, out_file=f)  

import os  
os.unlink('iris.dot')  

import pydotplus  
dot_data = tree.export_graphviz(clf2, out_file=None)  
graph2 = pydotplus.graph_from_dot_data(dot_data)  
graph2.write_pdf("iris.pdf")  

from IPython.display import Image  
dot_data = tree.export_graphviz(clf2, out_file=None,  
                     feature_names=iris.feature_names,  
                     class_names=iris.target_names,  
                     filled=True, rounded=True,  # leaves_parallel=True, 
                     special_characters=True)  
graph2 = pydotplus.graph_from_dot_data(dot_data)

## Color of nodes
nodes = graph2.get_node_list()

for node in nodes:
    if node.get_label():
        values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')];
        color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],}
        values = color[values.index(max(values))]; # print(values)
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color)
        node.set_fillcolor(color )
#

Image(graph2.create_png() ) 
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明

至于确定叶子上的类,您的示例没有像虹膜数据集那样具有单个类的叶子。这很常见,可能需要过度拟合模型才能获得这种结果。对于许多交叉验证的模型,类的离散分布是最好的结果。

享受代码!