scikit-learn DecisionTreeClassifier.tree_.value有什么作用?

Ous*_*bri 1 python machine-learning decision-tree scikit-learn

我正在研究DecisionTreeClassifier模型,我想了解模型选择的路径.所以我需要知道什么价值给了

DecisionTreeClassifier.tree_.value
Run Code Online (Sandbox Code Playgroud)

谢谢,

des*_*aut 6

嗯,你是正确的,文档实际上是模糊的(但说实话,我也不确定它的用处).

让我们用虹膜数据复制文档中的示例:

from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
Run Code Online (Sandbox Code Playgroud)

要求clf.tree_.value,我们得到:

array([[[ 50.,  50.,  50.]],
       [[ 50.,   0.,   0.]],
       [[  0.,  50.,  50.]],
       [[  0.,  49.,   5.]],
       [[  0.,  47.,   1.]],
       [[  0.,  47.,   0.]],
       [[  0.,   0.,   1.]],
       [[  0.,   2.,   4.]],
       [[  0.,   0.,   3.]],
       [[  0.,   2.,   1.]],
       [[  0.,   2.,   0.]],
       [[  0.,   0.,   1.]],
       [[  0.,   1.,  45.]],
       [[  0.,   1.,   2.]],
       [[  0.,   1.,   0.]],
       [[  0.,   0.,   2.]],
       [[  0.,   0.,  43.]]])
Run Code Online (Sandbox Code Playgroud)

len(clf.tree_.value)
# 17
Run Code Online (Sandbox Code Playgroud)

要了解这个数组究竟代表什么,查看树形可视化非常有用(也可以在文档中找到,为方便起见,在此处复制):

在此输入图像描述

我们可以看到,树有17个节点; 仔细观察,我们发现value每个节点实际上是我们clf.tree_.value数组的一个元素.

所以,长话短说:

  • clf.tree_.value 是一个数组数组,长度等于树中的节点数
  • 它的每个元素数组(对应一个树节点)的长度等于类的数量(这里是3)
  • 这些3元素阵列中的每一个对应于最终在每个类的相应节点中的训练样本的量.

为了澄清最后一点的例子,考虑数组的第二个元素[[ 50., 0., 0.]](对应于橙色节点):它表示,在这个节点中,最终得到来自类#0的50个样本,并且零样本来自其他两个班级(#1和#2).

希望这可以帮助...