ash*_*jsu 21 random-forest cross-validation apache-spark apache-spark-ml apache-spark-mllib
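I want to evaluate a random forest that is being trained on some data. Is there a utility in Apache Spark that does this, or do I have to perform cross-validation manually?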
zer*_*323 36
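ML provides the CrossValidator class, which can be used to perform cross-validation and parameter search. Assuming your data is already preprocessed, you can add cross-validation as follows: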
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// [label: double, features: vector]
val trainingData: org.apache.spark.sql.DataFrame = ???
val nFolds: Int = ???
val numTrees: Int = ???
val metric: String = ???
val rf = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(numTrees)
val pipeline = new Pipeline().setStages(Array(rf))
val paramGrid = new ParamGridBuilder().build() // No parameter search
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  // "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
  .setMetricName(metric)
val cv = new CrossValidator()
  // ml.Pipeline with ml.classification.RandomForestClassifier
  .setEstimator(pipeline)
  // ml.evaluation.MulticlassClassificationEvaluator
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(nFolds)
val model = cv.fit(trainingData) // trainingData: DataFrame
Using PySpark:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
trainingData = ... # DataFrame[label: double, features: vector]
numFolds = ... # Integer
rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator() # + other params as in Scala
pipeline = Pipeline(stages=[rf])
paramGrid = (ParamGridBuilder()
    .addGrid(rf.numTrees, [3, 10])
    # .addGrid(...)  # Add other parameters
    .build())
crossval = CrossValidator(
estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=evaluator,
numFolds=numFolds)
model = crossval.fit(trainingData)
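To make it clearer what the CrossValidator calls above do conceptually, here is a rough pure-Python sketch of k-fold cross-validation over a parameter grid. It is not Spark's implementation: every name in it (build_grid, k_fold_indices, cross_validate, the toy threshold "model") is hypothetical, it assumes a larger metric is better, and it skips the final refit of the best parameters on the full dataset that Spark performs.

```python
import itertools
import random
from statistics import mean


def build_grid(grid):
    """Expand {name: [values]} into a list of parameter dicts (Cartesian product).

    An empty grid yields one empty dict, so a single default
    configuration is still trained and scored.
    """
    names = list(grid)
    return [dict(zip(names, combo))
            for combo in itertools.product(*(grid[n] for n in names))]


def k_fold_indices(n_rows, n_folds, seed=0):
    """Shuffle row indices and deal them into n_folds disjoint folds."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    return [indices[i::n_folds] for i in range(n_folds)]


def cross_validate(rows, n_folds, grid, fit, score):
    """Return (best_params, avg_metrics); a higher score is assumed to be better."""
    folds = k_fold_indices(len(rows), n_folds)
    param_maps = build_grid(grid)
    avg_metrics = []
    for params in param_maps:
        fold_scores = []
        for held_out in folds:
            held = set(held_out)
            # Train on every row outside the held-out fold, score on the fold.
            train = [r for i, r in enumerate(rows) if i not in held]
            test = [rows[i] for i in held_out]
            fold_scores.append(score(fit(train, params), test))
        avg_metrics.append(mean(fold_scores))
    best = max(range(len(param_maps)), key=avg_metrics.__getitem__)
    return param_maps[best], avg_metrics


# Toy usage: the "model" is just a threshold on a single feature.
def fit_threshold(train, params):
    return params.get("threshold", 0.5)


def accuracy(threshold, test):
    return mean(1.0 if (x > threshold) == (label == 1.0) else 0.0
                for x, label in test)


rows = [(i / 20, 1.0 if i >= 10 else 0.0) for i in range(20)]
best_params, avg_metrics = cross_validate(
    rows, 4, {"threshold": [0.1, 0.45, 0.9]}, fit_threshold, accuracy)
print(best_params, avg_metrics)
```

Here avg_metrics holds one averaged score per parameter combination, in grid order, which is roughly the role the avgMetrics attribute plays on the fitted CrossValidatorModel.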