Posts by M. *_*eto

How to access Spark PipelineModel parameters

I am running a linear regression in pyspark using Spark Pipelines. Once the linear regression model has been trained, how do I get its coefficients?

Here is my pipeline code:

# Get all of our features together into one array called "features".  Do not include the label!
feature_assembler = VectorAssembler(inputCols=get_column_names(df_train), outputCol="features")

# Define our model
lr = LinearRegression(maxIter=100, elasticNetParam=0.80, labelCol="label",
                      featuresCol="features", predictionCol="prediction")

# Define our pipeline
pipeline_baseline = Pipeline(stages=[feature_assembler, lr])

# Train our model using the training data
model_baseline = pipeline_baseline.fit(df_train)

# Use our trained model to make predictions using the validation data
output_baseline = model_baseline.transform(df_val)  #.select("features", "label", "prediction", "coefficients")
predictions_baseline = output_baseline.select("label", …

python apache-spark pyspark pyspark-sql apache-spark-ml

5 votes
1 answer
2646 views

Spark ML Pipeline causes java.lang.Exception: failed to compile ... Code ... grows beyond 64 KB

Using Spark 2.0, I am trying to run a simple VectorAssembler in a pyspark ML pipeline, like so:

feature_assembler = VectorAssembler(inputCols=['category_count', 'name_count'],
                                    outputCol="features")
pipeline = Pipeline(stages=[feature_assembler])
model = pipeline.fit(df_train)
model_output = model.transform(df_train)

When I then try to view the output using

model_output.select("features").show(1)

I get the error:

Py4JJavaError                             Traceback (most recent call last)
<ipython-input-95-7a3e3d4f281c> in <module>()
      2 
      3 
----> 4 model_output.select("features").show(1)

/usr/local/spark20/python/pyspark/sql/dataframe.pyc in show(self, n, truncate)
    285         +---+-----+
    286         """
--> 287         print(self._jdf.showString(n, truncate))
    288 
    289     def __repr__(self):

/usr/local/spark20/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py in __call__(self, *args)
    931         answer = self.gateway_client.send_command(command)
    932         return_value = get_return_value(
--> 933             answer, self.gateway_client, self.target_id, self.name)
    934 
    935         for temp_arg in temp_args:

/usr/local/spark20/python/pyspark/sql/utils.pyc in deco(*a, **kw) …

python apache-spark apache-spark-sql pyspark pyspark-sql

5 votes
1 answer
833 views