Moh*_*mad 27 pipeline scala cross-validation apache-spark apache-spark-mllib
我想ParamGridBuilder在Spark 1.4.x中找到CrossValidator中最佳模型的参数,
在Spark文档中的Pipeline示例中,它们通过在管道中使用来添加不同的参数(numFeatures,regParam)ParamGridBuilder.然后通过以下代码行创建最佳模型:
val cvModel = crossval.fit(training.toDF)
Run Code Online (Sandbox Code Playgroud)
现在,我想知道从中产生最佳模型的参数(numFeatures,regParam)是什么ParamGridBuilder.
我已经使用了以下命令但没有成功:
cvModel.bestModel.extractParamMap().toString()
cvModel.params.toList.mkString("(", ",", ")")
cvModel.estimatorParamMaps.toString()
cvModel.explainParams()
cvModel.getEstimatorParamMaps.mkString("(", ",", ")")
cvModel.toString()
Run Code Online (Sandbox Code Playgroud)
有帮助吗?
提前致谢,
小智 17
获取正确ParamMap对象的一种方法是使用CrossValidatorModel.avgMetrics: Array[Double]查找argmax ParamMap:
implicit class BestParamMapCrossValidatorModel(cvModel: CrossValidatorModel) {
def bestEstimatorParamMap: ParamMap = {
cvModel.getEstimatorParamMaps
.zip(cvModel.avgMetrics)
.maxBy(_._2)
._1
}
}
Run Code Online (Sandbox Code Playgroud)
当CrossValidatorModel您在受管道示例中训练时运行时,您引用了:
scala> println(cvModel.bestEstimatorParamMap)
{
hashingTF_2b0b8ccaeeec-numFeatures: 100,
logreg_950a13184247-regParam: 0.1
}
Run Code Online (Sandbox Code Playgroud)
小智 12
val bestPipelineModel = cvModel.bestModel.asInstanceOf[PipelineModel]
val stages = bestPipelineModel.stages
val hashingStage = stages(1).asInstanceOf[HashingTF]
println("numFeatures = " + hashingStage.getNumFeatures)
val lrStage = stages(2).asInstanceOf[LogisticRegressionModel]
println("regParam = " + lrStage.getRegParam)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
11806 次 |
| 最近记录: |