scikit-learn在哪里保存树结构中每个叶节点的决策标签?

who*_*AmI 9 python tree-structure decision-tree scikit-learn

我已经使用scikit-learn训练了一个随机的森林模型,现在我想将它的树结构保存在文本文件中,以便我可以在其他地方使用它.根据这个链接,树对象由许多并行数组组成,每个数组都包含有关树的不同节点的一些信息(例如,左子,右子,它检查的特征,......).但是,似乎没有关于每个叶节点对应的类标签的信息!在上面的链接中提供的示例中甚至没有提到它.

有谁知道scikit-learn决策树结构中存储的类标签在哪里?

bla*_*ite 6

看看以下文档sklearn.tree.DecisionTreeClassifier.tree_.value

from sklearn.datasets import load_iris
from sklearn.cross_validation import cross_val_score
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()

clf.fit(iris.data, iris.target)

print(clf.classes_)

[0, 1, 2]

print(clf.tree_.value)

[[[ 50.  50.  50.]]

 [[ 50.   0.   0.]]

 [[  0.  50.  50.]]

 [[  0.  49.   5.]]

 [[  0.  47.   1.]]

 [[  0.  47.   0.]]

 [[  0.   0.   1.]]

 [[  0.   2.   4.]]

 [[  0.   0.   3.]]

 [[  0.   2.   1.]]

 [[  0.   2.   0.]]

 [[  0.   0.   1.]]

 [[  0.   1.  45.]]

 [[  0.   1.   2.]]

 [[  0.   0.   2.]]

 [[  0.   1.   0.]]

 [[  0.   0.  43.]]]
Run Code Online (Sandbox Code Playgroud)

clf.tree_.value“”中的每一行“包含每个节点的常数预测值”(help(clf.tree_)),它的索引到索引对应于clf.classes_

有关更多详情,请参见此答案

  • 添加到答案中,对于此数组中的每一行,您都可以执行“ clf.classes_ [np.argmax(value)]”来获取预测的类标签。 (3认同)