如何从CrossValidatorModel中提取最佳参数

Moh*_*mad 27 pipeline scala cross-validation apache-spark apache-spark-mllib

我想ParamGridBuilder在Spark 1.4.x中找到CrossValidator中最佳模型的参数,

在Spark文档中的Pipeline示例中,它们通过在管道中使用来添加不同的参数(numFeatures,regParam)ParamGridBuilder.然后通过以下代码行创建最佳模型:

val cvModel = crossval.fit(training.toDF)
Run Code Online (Sandbox Code Playgroud)

现在,我想知道从中产生最佳模型的参数(numFeatures,regParam)是什么ParamGridBuilder.

我已经使用了以下命令但没有成功:

cvModel.bestModel.extractParamMap().toString()
cvModel.params.toList.mkString("(", ",", ")")
cvModel.estimatorParamMaps.toString()
cvModel.explainParams()
cvModel.getEstimatorParamMaps.mkString("(", ",", ")")
cvModel.toString()
Run Code Online (Sandbox Code Playgroud)

有帮助吗?

提前致谢,

小智 17

获取正确ParamMap对象的一种方法是使用CrossValidatorModel.avgMetrics: Array[Double]查找argmax ParamMap:

implicit class BestParamMapCrossValidatorModel(cvModel: CrossValidatorModel) {
  def bestEstimatorParamMap: ParamMap = {
    cvModel.getEstimatorParamMaps
           .zip(cvModel.avgMetrics)
           .maxBy(_._2)
           ._1
  }
}
Run Code Online (Sandbox Code Playgroud)

CrossValidatorModel您在受管道示例中训练时运行时,您引用了:

scala> println(cvModel.bestEstimatorParamMap)
{
   hashingTF_2b0b8ccaeeec-numFeatures: 100,
   logreg_950a13184247-regParam: 0.1
}
Run Code Online (Sandbox Code Playgroud)

  • 注意:`maxBy`可能需要是`minBy`,具体取决于`Evaluator.isLargerBetter`的值. (6认同)

小智 12

val bestPipelineModel = cvModel.bestModel.asInstanceOf[PipelineModel]
val stages = bestPipelineModel.stages

val hashingStage = stages(1).asInstanceOf[HashingTF]
println("numFeatures = " + hashingStage.getNumFeatures)

val lrStage = stages(2).asInstanceOf[LogisticRegressionModel]
println("regParam = " + lrStage.getRegParam)
Run Code Online (Sandbox Code Playgroud)

资源