Tim*_*Tim 5 scala apache-spark apache-spark-sql apache-spark-ml apache-spark-mllib
我一直在尝试使用成人数据集在Spark和Scala中运行示例.
使用Scala 2.11.8和Spark 1.6.1.
问题(目前)在于该数据集中的分类特征量,在Spark ML算法完成其工作之前,所有分类特征都需要编码为数字.
到目前为止我有这个:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
object Adult {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("Adult example").setMaster("local[*]")
val sparkContext = new SparkContext(conf)
val sqlContext = new SQLContext(sparkContext)
val data = sqlContext.read
.format("com.databricks.spark.csv")
.option("header", "true") // Use first line of all files as header
.option("inferSchema", "true") // Automatically infer data types
.load("src/main/resources/adult.data")
val categoricals = data.dtypes filter (_._2 == "StringType")
val encoders = categoricals map (cat => new OneHotEncoder().setInputCol(cat._1).setOutputCol(cat._1 + "_encoded"))
val features = data.dtypes filterNot (_._1 == "label") map (tuple => if(tuple._2 == "StringType") tuple._1 + "_encoded" else tuple._1)
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(0.01)
val pipeline = new Pipeline()
.setStages(encoders ++ Array(lr))
val model = pipeline.fit(training)
}
}
Run Code Online (Sandbox Code Playgroud)
但是,这不起作用.调用pipeline.fit仍包含原始字符串功能,因此会引发异常.如何"StringType"在管道中删除这些列?或者我可能完全错了,所以如果有人有不同的建议,我很高兴所有输入:).
我选择遵循这个流程的原因是因为我在Python和Pandas中有广泛的背景,但我正在尝试学习Scala和Spark.
如果你已经习惯了更高级别的框架,那么有一点可能会让人感到困惑.您必须先索引功能,然后才能使用编码器.正如API文档中所解释的那样:
one-hot encoder(...)将一列类别索引映射到一列二进制向量,每行最多一个单值,表示输入类别索引.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
val df = Seq((1L, "foo"), (2L, "bar")).toDF("id", "x")
val categoricals = df.dtypes.filter (_._2 == "StringType") map (_._1)
val indexers = categoricals.map (
c => new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx")
)
val encoders = categoricals.map (
c => new OneHotEncoder().setInputCol(s"${c}_idx").setOutputCol(s"${c}_enc")
)
val pipeline = new Pipeline().setStages(indexers ++ encoders)
val transformed = pipeline.fit(df).transform(df)
transformed.show
// +---+---+-----+-------------+
// | id| x|x_idx| x_enc|
// +---+---+-----+-------------+
// | 1|foo| 1.0| (1,[],[])|
// | 2|bar| 0.0|(1,[0],[1.0])|
// +---+---+-----+-------------+
Run Code Online (Sandbox Code Playgroud)
如您所见,不需要从管道中删除字符串列.在实践中OneHotEncoder会接受数字列NominalAttribute,BinaryAttribute或缺少类型属性.
| 归档时间: |
|
| 查看次数: |
3035 次 |
| 最近记录: |