mkr*_*sel 5 scala apache-spark apache-spark-ml
我在Spark中有一个RDD,其中对象基于案例类:
ExampleCaseClass(user: User, stuff: Stuff)
Run Code Online (Sandbox Code Playgroud)
我想使用Spark的ML管道,因此我将其转换为Spark数据帧。作为管道的一部分,我想将其中一列转换为条目为向量的列。由于我希望向量的长度随模型而变化,因此应将其内置到管道中,作为特征转换的一部分。
因此,我尝试按以下方式定义一个Transformer:
class MyTransformer extends Transformer {
val uid = ""
val num: IntParam = new IntParam(this, "", "")
def setNum(value: Int): this.type = set(num, value)
setDefault(num -> 50)
def transform(df: DataFrame): DataFrame = {
...
}
def transformSchema(schema: StructType): StructType = {
val inputFields = schema.fields
StructType(inputFields :+ StructField("colName", ???, true))
}
def copy (extra: ParamMap): Transformer = defaultCopy(extra)
}
Run Code Online (Sandbox Code Playgroud)
如何指定结果字段的数据类型(即填写???)?它是一些简单类(布尔,整数,双精度型等)的向量。看来VectorUDT可能起作用了,但这对Spark来说是私有的。由于任何RDD都可以转换为DataFrame,因此任何案例类都可以转换为自定义DataType。但是我无法弄清楚如何手动执行此转换,否则我可以将其应用于包装矢量的一些简单案例类。
此外,如果我为列指定矢量类型,当我去拟合模型时,VectorAssembler是否可以将矢量正确处理为单独的特征?
还是Spark的新手,尤其是ML Pipeline,所以不胜感激。
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
def transformSchema(schema: StructType): StructType = {
val inputFields = schema.fields
StructType(inputFields :+ StructField("colName", VectorType, true))
}
Run Code Online (Sandbox Code Playgroud)
在spark 2.1中,VectorType使VectorUDT公开可用:
package org.apache.spark.ml.linalg
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.sql.types.DataType
/**
* :: DeveloperApi ::
* SQL data types for vectors and matrices.
*/
@Since("2.0.0")
@DeveloperApi
object SQLDataTypes {
/** Data type for [[Vector]]. */
val VectorType: DataType = new VectorUDT
/** Data type for [[Matrix]]. */
val MatrixType: DataType = new MatrixUDT
}
Run Code Online (Sandbox Code Playgroud)
import org.apache.spark.mllib.linalg.{Vector, Vectors}
case class MyVector(vector: Vector)
val vectorDF = Seq(
MyVector(Vectors.dense(1.0,3.4,4.4)),
MyVector(Vectors.dense(5.5,6.7))
).toDF
vectorDF.printSchema
root
|-- vector: vector (nullable = true)
println(vectorDF.schema.fields(0).dataType.prettyJson)
{
"type" : "udt",
"class" : "org.apache.spark.mllib.linalg.VectorUDT",
"pyClass" : "pyspark.mllib.linalg.VectorUDT",
"sqlType" : {
"type" : "struct",
"fields" : [ {
"name" : "type",
"type" : "byte",
"nullable" : false,
"metadata" : { }
}, {
"name" : "size",
"type" : "integer",
"nullable" : true,
"metadata" : { }
}, {
"name" : "indices",
"type" : {
"type" : "array",
"elementType" : "integer",
"containsNull" : false
},
"nullable" : true,
"metadata" : { }
}, {
"name" : "values",
"type" : {
"type" : "array",
"elementType" : "double",
"containsNull" : false
},
"nullable" : true,
"metadata" : { }
} ]
}
}
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
4732 次 |
| 最近记录: |