How to cross-validate a RandomForest model?

ash*_*jsu 21 random-forest cross-validation apache-spark apache-spark-ml apache-spark-mllib

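I want to evaluate a random forest that I am training on some data. Is there a utility in Apache Spark that does this, or do I have to perform cross-validation manually?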

zer*_*323 36

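ML provides the CrossValidator class, which can be used to perform cross-validation and parameter search. Assuming your data has already been preprocessed, you can add cross-validation like this: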

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// [label: double, features: vector]
val trainingData: org.apache.spark.sql.DataFrame = ???
val nFolds: Int = ???
val numTrees: Int = ???
val metric: String = ???

val rf = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(numTrees)

val pipeline = new Pipeline().setStages(Array(rf)) 

val paramGrid = new ParamGridBuilder().build() // No parameter search

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  // "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
  .setMetricName(metric) 

val cv = new CrossValidator()
  // ml.Pipeline with ml.classification.RandomForestClassifier
  .setEstimator(pipeline)
  // ml.evaluation.MulticlassClassificationEvaluator
  .setEvaluator(evaluator) 
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(nFolds)

val model = cv.fit(trainingData) // trainingData: DataFrame
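Because `ParamGridBuilder().build()` with no `addGrid` calls yields a single empty ParamMap, the Scala setup above does no parameter search; it only estimates the chosen metric for one configuration, averaged over the folds. The plain-Python sketch below shows roughly what that computation amounts to. It is not Spark code: `majority_fit` is a hypothetical stand-in for `rf.fit`, `accuracy` stands in for the evaluator, and the round-robin fold split is a simplification.

```python
import random
from collections import Counter
from statistics import mean


def majority_fit(rows):
    """Hypothetical stand-in for rf.fit: a 'model' that always predicts
    the most common label seen in its training rows."""
    majority = Counter(label for label, _ in rows).most_common(1)[0][0]
    return lambda features: majority


def accuracy(model, rows):
    """Stand-in for an evaluator with metricName "accuracy"."""
    return sum(model(features) == label for label, features in rows) / len(rows)


def cross_validate(rows, fit, metric, n_folds, seed=0):
    """Average `metric` over n_folds train/validation splits, i.e. roughly
    what a cross-validator reports for a single parameter setting."""
    idx = list(range(len(rows)))
    random.Random(seed).shuffle(idx)
    # Simplified split: deal shuffled indices round-robin into folds.
    folds = [idx[i::n_folds] for i in range(n_folds)]
    scores = []
    for k in range(n_folds):
        held_out = set(folds[k])
        train = [rows[i] for i in idx if i not in held_out]
        valid = [rows[i] for i in folds[k]]
        # Fit on the other folds, score on the held-out fold.
        scores.append(metric(fit(train), valid))
    return mean(scores)


# Rows mirror the [label: double, features: vector] schema as (label, features).
rows = [(1.0, [0.0])] * 8 + [(0.0, [1.0])] * 2
print(cross_validate(rows, majority_fit, accuracy, n_folds=5))
```

After the folds are scored, Spark's cross-validator additionally refits the best setting on the whole training set; that final fit is what `cv.fit` returns as the best model.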

Using PySpark:

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

trainingData = ... # DataFrame[label: double, features: vector]
numFolds = ... # Integer

rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator() # + other params as in Scala    

pipeline = Pipeline(stages=[rf])
paramGrid = (ParamGridBuilder()
    .addGrid(rf.numTrees, [3, 10])
    # .addGrid(...)  # add grids for other parameters here
    .build())

crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=numFolds)

model = crossval.fit(trainingData)
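`ParamGridBuilder` expands the grids into the Cartesian product of all listed values, and the cross-validator fits one model per (parameter map, fold) pair before refitting the best map on the full data. The plain-Python sketch below illustrates that bookkeeping. It is not the PySpark API: parameter maps are modeled as dicts, and the `maxDepth` values and metric numbers in the usage lines are made up for illustration.

```python
from itertools import product


def build_grid(grid):
    """Expand {param_name: [values]} into the Cartesian product of settings,
    analogous to the list of ParamMaps that ParamGridBuilder.build() returns.
    An empty grid yields one empty setting, just like build() with no addGrid."""
    names = list(grid)
    return [dict(zip(names, combo)) for combo in product(*(grid[n] for n in names))]


def num_fits(param_maps, num_folds):
    """One fit per (param map, fold) pair, plus one refit of the best map."""
    return len(param_maps) * num_folds + 1


def pick_best(param_maps, avg_metrics, larger_is_better=True):
    """Choose the setting with the best fold-averaged metric
    (larger is better for metrics such as "f1" or "accuracy")."""
    pick = max if larger_is_better else min
    best = pick(range(len(param_maps)), key=lambda i: avg_metrics[i])
    return param_maps[best]


maps = build_grid({"numTrees": [3, 10], "maxDepth": [5, 10]})  # 4 settings
print(num_fits(maps, 3))                        # model fits with 3 folds
print(pick_best(maps, [0.7, 0.9, 0.8, 0.85]))   # hypothetical averaged metrics
```

The fit count grows multiplicatively with every `addGrid` call, so even modest grids can become expensive on large data.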

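  • No, I'm pretty sure it doesn't. `MLUtils.kFold` uses `BernoulliCellSampler` to determine the splits. On the other hand, leave-one-out cross-validation in Spark would probably be far too expensive to be practical. (5 upvotes)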
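As far as I can tell from the Spark sources, `MLUtils.kFold` builds, for each fold, a `BernoulliCellSampler` over the range [(k-1)/numFolds, k/numFolds) for the validation set and the complementary sampler for the training set, both seeded identically so every row sees the same random draw. The plain-Python sketch below imitates that idea on a list; it is a simplified single-machine illustration, not the Spark implementation (in particular it draws once per item instead of per partition).

```python
import random


def bernoulli_cell_kfold(items, num_folds, seed=0):
    """Simplified sketch of Bernoulli-cell k-fold splitting.

    Every item gets one uniform draw u in [0, 1). Fold k's validation set
    holds the items with k/num_folds <= u < (k+1)/num_folds; its training
    set is the complement. Labels are ignored, so folds are not stratified
    and their sizes are only approximately equal."""
    rng = random.Random(seed)
    draws = [rng.random() for _ in items]  # one draw per item, reused by every fold
    splits = []
    for k in range(num_folds):
        lb, ub = k / num_folds, (k + 1) / num_folds
        validation = [x for x, u in zip(items, draws) if lb <= u < ub]
        training = [x for x, u in zip(items, draws) if not (lb <= u < ub)]
        splits.append((training, validation))
    return splits


splits = bernoulli_cell_kfold(list(range(1000)), num_folds=5, seed=1)
print([len(validation) for _, validation in splits])  # roughly 200 each
```

This also shows why leave-one-out does not fit this scheme: setting `num_folds` to the number of rows gives validation sets of random size (some empty) rather than exactly one row each, and true leave-one-out would need one model fit per row.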