Pau*_*aul 21 modeling cross-validation pyspark apache-spark-ml apache-spark-mllib
I'm tinkering with some cross-validation code from the PySpark documentation, and I'm trying to get PySpark to tell me which model was selected:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.linalg import Vectors
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
dataset = sqlContext.createDataFrame(
[(Vectors.dense([0.0]), 0.0),
(Vectors.dense([0.4]), 1.0),
(Vectors.dense([0.5]), 0.0),
(Vectors.dense([0.6]), 1.0),
(Vectors.dense([1.0]), 1.0)] * 10,
["features", "label"])
lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.regParam, [0.1, 0.01, 0.001, 0.0001]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
cvModel = cv.fit(dataset)
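Running this in the PySpark shell, I can get the coefficients of the logistic regression model, but I can't seem to find the value of lr.regParam that the cross-validation procedure selected. Any ideas?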
In [3]: cvModel.bestModel.coefficients
Out[3]: DenseVector([3.1573])
In [4]: cvModel.bestModel.explainParams()
Out[4]: ''
In [5]: cvModel.bestModel.extractParamMap()
Out[5]: {}
In [15]: cvModel.params
Out[15]: []
In [36]: cvModel.bestModel.params
Out[36]: []
wer*_*hao 27
I ran into this problem too. I found that you need to go through the underlying Java object, for a reason I don't know. So do this:
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder, CrossValidator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(metricName="mae")
lr = LinearRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [500]) \
.addGrid(lr.regParam, [0]) \
.addGrid(lr.elasticNetParam, [1]) \
.build()
lr_cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, \
evaluator=evaluator, numFolds=3)
lrModel = lr_cv.fit(your_training_set_here)
bestModel = lrModel.bestModel
Then print the parameters you want:
>>> print 'Best Param (regParam): ', bestModel._java_obj.getRegParam()
0
>>> print 'Best Param (MaxIter): ', bestModel._java_obj.getMaxIter()
500
>>> print 'Best Param (elasticNetParam): ', bestModel._java_obj.getElasticNetParam()
1
The same applies to other methods such as extractParamMap(). They should fix this soon.
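If you want the values in variables rather than printed one at a time, the getter calls above can be wrapped in a small helper. This is only a sketch: `read_java_params` is a hypothetical name, not part of PySpark, and it assumes the fitted model exposes `_java_obj` with `getXxx()` getters, as the session above suggests. `_java_obj` is a private attribute, so this may behave differently across Spark versions.

```python
def read_java_params(model, names):
    """Return {param name: value} read from the JVM object behind a fitted model.

    Hypothetical helper (not a PySpark API). Assumes `model._java_obj` has a
    getter named "get" + capitalised param name, e.g. regParam -> getRegParam.
    """
    java_obj = model._java_obj
    values = {}
    for name in names:
        getter_name = "get" + name[0].upper() + name[1:]
        values[name] = getattr(java_obj, getter_name)()
    return values


# Intended use with the model fitted above (requires a running Spark session):
#   best = read_java_params(bestModel, ["regParam", "maxIter", "elasticNetParam"])
#   best["regParam"]
```

A misspelled parameter name raises AttributeError from `getattr`, which is probably what you want here.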
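This may not be as good as wernerchao's answer (because it is not convenient for storing the hyperparameters in variables), but you can quickly see the best hyperparameters of a cross-validation model this way: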
import numpy as np
cvModel.getEstimatorParamMaps()[np.argmax(cvModel.avgMetrics)]  # use np.argmin if a smaller metric is better (e.g. mae, rmse)
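One caveat with the one-liner: argmax only picks the right entry when a larger metric is better (areaUnderROC, for example). For error metrics such as mae you want the minimum. A minimal sketch that handles both cases without NumPy follows; `best_param_map` is a hypothetical helper name, and it assumes, as the one-liner does, that `avgMetrics` is in the same order as the param maps.

```python
def best_param_map(param_maps, avg_metrics, larger_is_better=True):
    """Return (param_map, metric) for the best-scoring entry of a tuning grid.

    Hypothetical helper (not a PySpark API). `param_maps` and `avg_metrics`
    must be parallel sequences, one averaged metric per param map.
    """
    if len(param_maps) != len(avg_metrics):
        raise ValueError("param_maps and avg_metrics must have the same length")
    if not param_maps:
        raise ValueError("param_maps is empty")
    pick = max if larger_is_better else min
    best_index = pick(range(len(avg_metrics)), key=lambda i: avg_metrics[i])
    return param_maps[best_index], avg_metrics[best_index]


# Intended use with a fitted CrossValidatorModel (requires a running Spark session):
#   params, metric = best_param_map(cvModel.getEstimatorParamMaps(),
#                                   cvModel.avgMetrics,
#                                   evaluator.isLargerBetter())
```

Passing `evaluator.isLargerBetter()` lets the evaluator decide the direction, so the same call should work for both classification and regression evaluators.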