如何从交叉验证器获得训练有素的最佳模型

Fan*_* L. 6 scala machine-learning decision-tree cross-validation apache-spark

我构建了一个包含这样的DecisionTreeClassifier(dt)的管道

val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
Run Code Online (Sandbox Code Playgroud)

然后我使用这个管道作为CrossValidator中的估算器,以获得具有这样的最佳超参数集的模型

val c_v = new CrossValidator().setEstimator(pipeline).setEvaluator(new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")).setEstimatorParamMaps(paramGrid).setNumFolds(5)
Run Code Online (Sandbox Code Playgroud)

最后,我可以使用这个交叉验证器在训练测试中训练模型

val model = c_v.fit(train)
Run Code Online (Sandbox Code Playgroud)

但问题是,我想查看受过最佳训练的决策树模型,参数.toDebugTreeDecisionTreeClassificationModel.但模型是一个CrossValidatorModel.是的,你可以使用model.bestModel,但它仍然是类型Model,你不能申请.toDebugTree它.而且我也承担bestModel仍包括pipline labelIndexer,featureIndexer,dt,labelConverter.

那么有谁知道我如何从拟合的模型中获得decisionTree模型crossvalidator,我可以通过它查看实际模型toDebugString?或者有没有可以查看decisionTree模型的解决方法?

zer*_*323 8

那么,在这样的情况下,答案总是相同的 - 具体的类型.

首先提取管道模型,因为您要训练的是管道:

import org.apache.spark.ml.PipelineModel

val bestModel: Option[PipelineModel] = model.bestModel match {
  case p: PipelineModel => Some(p)
  case _ => None
}
Run Code Online (Sandbox Code Playgroud)

然后,您需要从基础阶段提取模型.在您的情况下,它是一个决策树分类模型:

import org.apache.spark.ml.classification.DecisionTreeClassificationModel

val treeModel: Option[DecisionTreeClassificationModel] = bestModel
  flatMap {
    _.stages.collect {
      case t: DecisionTreeClassificationModel => t
    }.headOption
  }
Run Code Online (Sandbox Code Playgroud)

要打印树,例如:

treeModel.foreach(_.toDebugString)
Run Code Online (Sandbox Code Playgroud)