Get leaf probabilities from a tree model in Spark

nic*_*ola 6 apache-spark pyspark apache-spark-ml

I am trying to refactor a trained Spark tree-based model (a RandomForest or GBT classifier) so that it can be exported to environments without Spark. The toDebugString method is a good starting point. However, for a RandomForestClassifier the string only shows the predicted class of each tree, without the relative probabilities, so if I average the predictions over all the trees I get a wrong result.

An example. We have a DecisionTree represented this way:

DecisionTreeClassificationModel (uid=dtc_884dc2111789) of depth 2 with 5 nodes
  If (feature 21 in {1.0})
   Predict: 0.0
  Else (feature 21 not in {1.0})
   If (feature 10 in {0.0})
    Predict: 0.0
   Else (feature 10 not in {0.0})
    Predict: 1.0
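
For reference, a per-tree string like the one above can be printed roughly like this (rf_model is only an illustrative name for a fitted RandomForestClassificationModel):

print(rf_model.trees[0].toDebugString)  # debug string of the first tree in the forest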

As we can see, following these nodes it looks like the prediction is always either 0 or 1. However, if I apply this single tree to a vector of features, I get probabilities like [0.1007, 0.8993], and they make perfect sense, because the proportion of negative/positive training examples that end up in the same leaf as the example vector matches the output probabilities.
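
As an illustrative sketch only (rf_model and df are assumed names; df must contain the model's features column), those probabilities surface in the probability column when a single tree is applied as a transformer:

single_tree = rf_model.trees[0]  # one DecisionTreeClassificationModel from the forest
single_tree.transform(df).select("probability").show(truncate=False)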

My question: where are these probabilities stored? Is there a way to extract them, and if so, how? A pyspark solution would be preferable.

104*_*ica 5

"I am trying to refactor a trained Spark tree-based model (RandomForest or GBT classifier) so that it can be exported to environments without Spark."

Given the growing number of tools designed for real-time serving of Spark (and other) models, this is probably reinventing the wheel.

However, if you want to access the model internals from plain Python, it is best to load its serialized form.

Let's assume you have:

from pyspark.ml.classification import RandomForestClassificationModel

rf_model: RandomForestClassificationModel
path: str  # Absolute path

and then save the model:

rf_model.write().save(path)

You can load it back using any Parquet reader that supports mixed struct and list types. The model writer writes out both the node data:

node_data = spark.read.parquet("{}/data".format(path))

node_data.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- nodeData: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- prediction: double (nullable = true)
 |    |-- impurity: double (nullable = true)
 |    |-- impurityStats: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- rawCount: long (nullable = true)
 |    |-- gain: double (nullable = true)
 |    |-- leftChild: integer (nullable = true)
 |    |-- rightChild: integer (nullable = true)
 |    |-- split: struct (nullable = true)
 |    |    |-- featureIndex: integer (nullable = true)
 |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |    |-- element: double (containsNull = true)
 |    |    |-- numCategories: integer (nullable = true)
Run Code Online (Sandbox Code Playgroud)

and the trees metadata:

tree_meta = spark.read.parquet("{}/treesMetadata".format(path))
tree_meta.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- metadata: string (nullable = true)
 |-- weights: double (nullable = true)

The former provides all the information you need, since the prediction process is basically an aggregation of the impurityStats*.
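
For example, a minimal sketch of pulling the per-leaf probabilities straight out of that node data (the leftChild == -1 leaf marker and the column names are read off the schema above; normalizing the counts into probabilities is the assumption here):

from pyspark.sql import functions as F

# Leaf rows are assumed to be the ones without children (leftChild == -1)
leaves = (node_data
    .where(F.col("nodeData.leftChild") == -1)
    .select(
        "treeID",
        F.col("nodeData.id").alias("node_id"),
        F.col("nodeData.impurityStats").alias("stats")))

# Normalizing the per-class counts gives the leaf probabilities the question asks about
leaf_probs = {
    (r.treeID, r.node_id): [c / sum(r.stats) for c in r.stats]
    for r in leaves.collect()
}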

You could also access this data directly, using the underlying Java objects:

from  collections import namedtuple
import numpy as np

LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))

def jtree_to_python(jtree):
    def jsplit_to_python(jsplit):
        # Continuous splits expose a threshold; categorical splits a set of left-branch categories
        if jsplit.getClass().toString().endswith(".ContinuousSplit"):
            return ContinuousSplit(jsplit.featureIndex(), jsplit.threshold())
        else:
            jcat = jsplit.toOld().categories()
            return CategoricalSplit(
                jsplit.featureIndex(),
                [jcat.apply(i) for i in range(jcat.length())])

    def jnode_to_python(jnode):
        prediction = jnode.prediction()
        # impurityStats holds the per-class counts collected at this node during training
        stats = np.array(list(jnode.impurityStats().stats()))

        if jnode.numDescendants() != 0:  # InternalNode
            left = jnode_to_python(jnode.leftChild())
            right = jnode_to_python(jnode.rightChild())
            split = jsplit_to_python(jnode.split())

            return InternalNode(left, right, prediction, stats, split)            

        else:
            return LeafNode(prediction, stats) 

    return jnode_to_python(jtree.rootNode())

and applied to a RandomForestModel like this:

nodes = [jtree_to_python(t) for t in rf_model._java_obj.trees()]

Furthermore, such a structure can easily be used to make predictions, for both individual trees (warning: Python 3.7+ ahead; for legacy usage please check the functools documentation):

from functools import singledispatch

@singledispatch
def should_go_left(split, vector): pass

@should_go_left.register
def _(split: CategoricalSplit, vector):
    return vector[split.feature_index] in split.categories

@should_go_left.register
def _(split: ContinuousSplit, vector):
    return vector[split.feature_index] <= split.threshold

@singledispatch
def predict(node, vector):
    # Walk the tree recursively; returns (prediction, impurityStats) of the leaf that is reached
    pass

@predict.register
def _(node: LeafNode, vector):
    # For a leaf, node.impurity holds the impurityStats array (per-class counts)
    return node.prediction, node.impurity

@predict.register
def _(node: InternalNode, vector):
    return predict(
        node.left if should_go_left(node.split, vector) else node.right,
        vector
    )
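
and forests. The forest-level helper is missing from this copy of the answer, so the following is only a sketch (the name predict_probability is assumed to match the footnote below): each tree's leaf counts are normalized to class probabilities, summed across trees, and renormalized, which mirrors how Spark combines tree votes.

def predict_probability(nodes, vector):
    # Relies on predict() and numpy (np) defined above
    total = np.array([
        stats / stats.sum()
        for _, stats in (predict(node, vector) for node in nodes)
    ]).sum(axis=0)
    return total / total.sum()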

However, this depends on internal APIs (and the weakness of Scala package-scoped access modifiers) and might break in the future.

* A DataFrame loaded from the data path can easily be transformed into a structure compatible with the predict and predict_probability functions defined above.
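
A hedged sketch of that transformation, reusing the namedtuples, numpy import, and node_data DataFrame from above (the helper name model_data_to_tree, the leftChild == -1 leaf marker, numCategories == -1 for continuous splits, and node id 0 as the root are assumptions read off the schema printed earlier, not part of the original answer):

def model_data_to_tree(node_rows):
    # node_rows: the nodeData structs of a single treeID, collected as plain Rows
    by_id = {row.id: row for row in node_rows}

    def build(node_id):
        row = by_id[node_id]
        stats = np.array(row.impurityStats)
        if row.leftChild == -1:  # assumed leaf marker
            return LeafNode(row.prediction, stats)
        if row.split.numCategories == -1:  # assumed marker for continuous splits
            split = ContinuousSplit(
                row.split.featureIndex, row.split.leftCategoriesOrThreshold[0])
        else:
            split = CategoricalSplit(
                row.split.featureIndex, list(row.split.leftCategoriesOrThreshold))
        return InternalNode(
            build(row.leftChild), build(row.rightChild),
            row.prediction, stats, split)

    return build(0)  # the root is assumed to have id 0

# Group the rows per tree and rebuild the whole forest
per_tree = {}
for row in node_data.collect():
    per_tree.setdefault(row.treeID, []).append(row.nodeData)
nodes = [model_data_to_tree(rows) for rows in per_tree.values()]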