yjx*_*yjx 4 scala apache-spark udf
在我的项目中,我想实现ADD(+)函数,但我的参数可能LongType是DoubleType,IntType.我用sqlContext.udf.register("add",XXX),但我不知道怎么写XXX,这是制作泛型函数.
您可以UDF通过创建一个包含您的值并使您的工作脱离此功能的StructTypewith 来创建泛型.它会作为对象传递给您,因此您可以执行以下操作:struct($"col1", $"col2")UDFUDFRow
val multiAdd = udf[Double,Row](r => {
var n = 0.0
r.toSeq.foreach(n1 => n = n + (n1 match {
case l: Long => l.toDouble
case i: Int => i.toDouble
case d: Double => d
case f: Float => f.toDouble
}))
n
})
val df = Seq((1.0,2),(3.0,4)).toDF("c1","c2")
df.withColumn("add", multiAdd(struct($"c1", $"c2"))).show
+---+---+---+
| c1| c2|add|
+---+---+---+
|1.0| 2|3.0|
|3.0| 4|7.0|
+---+---+---+
Run Code Online (Sandbox Code Playgroud)
您甚至可以做一些有趣的事情,例如将可变数量的列作为输入.事实上,我们UDF上面定义的已经做到了:
val df = Seq((1, 2L, 3.0f,4.0),(5, 6L, 7.0f,8.0)).toDF("int","long","float","double")
df.printSchema
root
|-- int: integer (nullable = false)
|-- long: long (nullable = false)
|-- float: float (nullable = false)
|-- double: double (nullable = false)
df.withColumn("add", multiAdd(struct($"int", $"long", $"float", $"double"))).show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
| 1| 2| 3.0| 4.0|10.0|
| 5| 6| 7.0| 8.0|26.0|
+---+----+-----+------+----+
Run Code Online (Sandbox Code Playgroud)
您甚至可以在混音中添加硬编码的数字:
df.withColumn("add", multiAdd(struct(lit(100), $"int", $"long"))).show
+---+----+-----+------+-----+
|int|long|float|double| add|
+---+----+-----+------+-----+
| 1| 2| 3.0| 4.0|103.0|
| 5| 6| 7.0| 8.0|111.0|
+---+----+-----+------+-----+
Run Code Online (Sandbox Code Playgroud)
如果要使用UDFin SQL语法,可以执行以下操作:
sqlContext.udf.register("multiAdd", (r: Row) => {
var n = 0.0
r.toSeq.foreach(n1 => n = n + (n1 match {
case l: Long => l.toDouble
case i: Int => i.toDouble
case d: Double => d
case f: Float => f.toDouble
}))
n
})
df.registerTempTable("df")
// Note that 'int' and 'long' are column names
sqlContext.sql("SELECT *, multiAdd(struct(int, long)) as add from df").show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
| 1| 2| 3.0| 4.0| 3.0|
| 5| 6| 7.0| 8.0|11.0|
+---+----+-----+------+----+
Run Code Online (Sandbox Code Playgroud)
这也有效:
sqlContext.sql("SELECT *, multiAdd(struct(*)) as add from df").show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
| 1| 2| 3.0| 4.0|10.0|
| 5| 6| 7.0| 8.0|26.0|
+---+----+-----+------+----+
Run Code Online (Sandbox Code Playgroud)
我认为您无法注册通用 UDF。
如果我们看一下该方法的签名register(实际上,它只是 22 个重载之一register,用于带有一个参数的 UDF,其他都是等效的):
def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction
Run Code Online (Sandbox Code Playgroud)
我们可以看到它是用A1: TypeTag类型参数化的 - TypeTag 意味着在注册时,我们必须有UDF 参数的实际类型的证据。因此 - 传递通用函数func而不显式键入它无法编译。
对于您的情况,您也许能够利用 Spark 自动转换数字类型的功能 - 仅为Doubles 编写 UDF,并且您也可以将其应用于Ints (不过,输出将为Double):
sqlContext.udf.register("add", (i: Double) => i + 1)
// creating a table with Double and Int types:
sqlContext.createDataFrame(Seq((1.5, 4), (2.2, 5))).registerTempTable("table1")
// applying UDF to both types:
sqlContext.sql("SELECT add(_1), add(_2) FROM table1").show()
// output:
// +---+---+
// |_c0|_c1|
// +---+---+
// |2.5|5.0|
// |3.2|6.0|
// +---+---+
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3362 次 |
| 最近记录: |