Mapping scikit-learn DecisionTreeClassifier.tree_.value to the predicted class

nem*_*emi 6 python decision-tree scikit-learn

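I am using scikit-learn's DecisionTreeClassifier on a 3-class dataset. After fitting the classifier, I access all the leaf nodes through the tree_ attribute to get the number of instances of each class that end up in a given node.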

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_depth=5)
clf.fit(X, y)
# let's assume there is a leaf node with id 5
print(clf.tree_.value[5])

This prints:

>>> array([[  0.,   1.,  68.]])

But how do I know which position in that array belongs to which class? The classifier has a classes_ attribute, which is also an array:

>>> clf.classes_
array(['CLASS_1', 'CLASS_2', 'CLASS_3'], dtype=object)

Perhaps index 1 of the value array matches the class at index 1 of the classes_ array, and so on?

nem*_*emi 7

I asked about this on the scikit-learn mailing list, and my guess was right: index 1 of the value array corresponds to the class at index 1 of the classes_ array, and so on.
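
For anyone who wants to check the correspondence themselves, here is a minimal sketch. It assumes the Iris dataset as a stand-in for the 3-class data in the question, and the variable names (leaf_ids, leaf_values, leaf_classes) are only illustrative. Depending on the scikit-learn version, the rows of tree_.value may hold per-class sample counts (as in the output in the question) or class fractions; taking the argmax works in both cases.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Iris stands in for the 3-class dataset from the question (an assumption).
iris = load_iris()
X = iris.data
y = iris.target_names[iris.target]  # string labels, similar to 'CLASS_1' etc.

clf = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X, y)
tree_ = clf.tree_

# Leaf nodes have no children; scikit-learn marks a missing child with -1.
leaf_ids = np.flatnonzero(tree_.children_left == -1)

# tree_.value has shape (n_nodes, n_outputs, n_classes); take the single output.
leaf_values = tree_.value[leaf_ids, 0, :]

# Column k of each row refers to clf.classes_[k].
leaf_classes = clf.classes_[leaf_values.argmax(axis=1)]

for node_id, row, label in zip(leaf_ids, leaf_values, leaf_classes):
    print(node_id, dict(zip(clf.classes_, row)), "->", label)

# Sanity check: the same lookup reproduces clf.predict on the training samples.
sample_leaves = clf.apply(X)  # leaf node id that each sample lands in
looked_up = clf.classes_[tree_.value[sample_leaves, 0, :].argmax(axis=1)]
assert (looked_up == clf.predict(X)).all()
```

The final assertion passing is what confirms the index alignment: looking up classes_ by the position of the largest entry in a leaf's value row gives the same label that predict returns for samples in that leaf.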