I am trying to create a Spark ML Pipeline with a Random Forest Classifier to perform classification (not regression), but I am getting an error saying the label in my training set should be a double rather than an integer. I followed the instructions on these pages:
"Classification and regression - spark.ml" (apache.org)
"How to create correct data frame for classification in Spark ML" (stackoverflow.com)
"Spark MLLib - Predict Store Sales with ML Pipelines" (sparktutorials.net)
I have a Spark dataframe with the following columns:
scala> df.show(5)
+-------+----------+----------+---------+-----+
| userId|duration60|duration30|duration1|label|
+-------+----------+----------+---------+-----+
|user000| 11| 21| 35| 3|
|user001| 28| 41| 28| 4|
|user002| 17| 6| 8| 2|
|user003| 39| 29| 0| 1|
|user004| 26| 23| 25| 3|
+-------+----------+----------+---------+-----+
scala> df.printSchema()
root
|-- userId: string (nullable = true)
|-- duration60: integer (nullable = true)
|-- duration30: integer (nullable = true)
|-- duration1: integer (nullable = true)
|-- label: integer (nullable = true)
I am using the feature columns duration60, duration30, and duration1 to predict the categorical column label.
I then set up my Spark script like this:
import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.{Pipeline, PipelineModel}
Logger.getLogger("org").setLevel(Level.ERROR)
Logger.getLogger("akka").setLevel(Level.ERROR)
val sqlContext = new SQLContext(sc)
val df = sqlContext.read.
format("com.databricks.spark.csv").
option("header", "true"). // Use first line of all files as header
option("inferSchema", "true"). // Automatically infer data types
load("/tmp/features.csv").
withColumnRenamed("satisfaction", "label").
select("userId", "duration60", "duration30", "duration1", "label")
val assembler = new VectorAssembler().
setInputCols(Array("duration60", "duration30", "duration1")).
setOutputCol("features")
val randomForest = new RandomForestClassifier().
setLabelCol("label").
setFeaturesCol("features").
setNumTrees(10)
val pipeline = new Pipeline().setStages(Array(assembler, randomForest))
val model = pipeline.fit(df)
The transformed dataframe looks like this:
scala> assembler.transform(df).show(5)
+-------+----------+----------+---------+-----+----------------+
| userId|duration60|duration30|duration1|label| features|
+-------+----------+----------+---------+-----+----------------+
|user000| 11| 21| 35| 3|[11.0,21.0,35.0]|
|user001| 28| 41| 28| 4|[28.0,41.0,28.0]|
|user002| 17| 6| 8| 2| [17.0,6.0,8.0]|
|user003| 39| 29| 0| 1| [39.0,29.0,0.0]|
|user004| 26| 23| 25| 3|[26.0,23.0,25.0]|
+-------+----------+----------+---------+-----+----------------+
However, the last line throws an exception:
java.lang.IllegalArgumentException: requirement failed: Column label must be of type DoubleType but was actually IntegerType.
What does this mean, and how do I fix it?
Why does the label column have to be a double? I am doing prediction, not regression, so I thought a string or an integer would be appropriate; a prediction column of double values usually suggests regression.
Cast the label column to DoubleType, since that is the type the algorithm expects. In spark.ml, classifiers encode class labels as doubles (0.0, 1.0, 2.0, ...); this is a representation requirement, not an indication that regression is being performed.
import org.apache.spark.sql.types._
import sqlContext.implicits._ // enables the 'label symbol syntax (already in scope in spark-shell)
df.withColumn("label", 'label cast DoubleType)
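As a quick sanity check (a minimal sketch using the df from above), print the schema of the casted frame; label should now report as double:
df.withColumn("label", 'label cast DoubleType).printSchema()
// ...
// |-- label: double (nullable = true)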
So, do the cast right where you define val df, as the last line of the chain:
import org.apache.spark.sql.types._
val df = sqlContext.read.
format("com.databricks.spark.csv").
option("header", "true"). // Use first line of all files as header
option("inferSchema", "true"). // Automatically infer data types
load("/tmp/features.csv").
withColumnRenamed("satisfaction", "label").
select("userId", "duration60", "duration30", "duration1", "label")
.withColumn("label", 'label cast DoubleType) // <-- HERE
Note that I used the 'label syntax (a single quote ' followed by the column name) to reference the label column (I could equally have used $"label", col("label"), df("label"), or column("label")).
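For completeness, a minimal sketch of those equivalent forms (assuming the df and imports from above; the $"label" interpolator also needs sqlContext.implicits._); each builds the same Column and casts it to double:
import org.apache.spark.sql.functions.{col, column}
df.withColumn("label", $"label".cast(DoubleType))        // string interpolator from sqlContext.implicits._
df.withColumn("label", col("label").cast(DoubleType))    // functions.col
df.withColumn("label", df("label").cast(DoubleType))     // apply on the DataFrame itself
df.withColumn("label", column("label").cast(DoubleType)) // functions.column
With any of these in place, pipeline.fit(df) should complete without the IllegalArgumentException.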