Spark UDAF - using generics as the input type?

ser*_*eda 2 scala aggregate-functions user-defined-functions apache-spark apache-spark-sql

I want to write a Spark UDAF where the column type can be anything that has a Scala Numeric defined on it. I have searched the web, but have only found examples with concrete types such as DoubleType or LongType. Is this not possible? If it is, how can such a UDAF be used with other numeric types?

use*_*411 10

For simplicity, let's assume you want to define a custom sum. You provide a TypeTag for the input type and use Scala reflection to define the schema:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor

case class MySum[T: TypeTag](implicit n: Numeric[T])
    extends UserDefinedAggregateFunction {

  // Derive the Spark SQL DataType for T via reflection
  val dt = schemaFor[T].dataType
  def inputSchema = new StructType().add("x", dt)
  def bufferSchema = new StructType().add("x", dt)

  def dataType = dt
  def deterministic = true

  // Start the running sum at Numeric's zero
  def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, n.zero)

  // Add each non-null input value to the buffer
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, n.plus(buffer.getAs[T](0), input.getAs[T](0)))
  }

  // Combine partial sums from two buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, n.plus(buffer1.getAs[T](0), buffer2.getAs[T](0)))
  }

  def evaluate(buffer: Row) = buffer.getAs[T](0)
}

With the function defined above, we can create instances that handle specific types:

val sumOfLong = MySum[Long]
spark.range(10).select(sumOfLong($"id")).show
+---------+
|mysum(id)|
+---------+
|       45|
+---------+
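
Because the type parameter only needs a Numeric and a TypeTag, the same definition can be reused for other numeric columns. As a minimal sketch (the cast is only there because range produces a LongType column), an instance for Double:

val sumOfDouble = MySum[Double]
spark.range(10).select(sumOfDouble($"id".cast("double"))).show
// the aggregated value is 45.0, and the result type is DoubleType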

Note:

To get the same flexibility as the built-in aggregate functions you would have to define your own AggregateFunction, such as an ImperativeAggregate or a DeclarativeAggregate. It is possible, but it is an internal API.
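
On Spark 3.0+, where UserDefinedAggregateFunction is deprecated, a similarly generic sum can also be expressed with the public typed Aggregator API and functions.udaf. A minimal sketch, assuming an implicit Encoder[T] is available (e.g. from spark.implicits._):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.Encoder

// Generic typed sum: input, buffer and output all share the element type T
class TypedSum[T](implicit n: Numeric[T], enc: Encoder[T])
    extends Aggregator[T, T, T] {
  def zero: T = n.zero
  def reduce(b: T, a: T): T = n.plus(b, a)
  def merge(b1: T, b2: T): T = n.plus(b1, b2)
  def finish(r: T): T = r
  def bufferEncoder: Encoder[T] = enc
  def outputEncoder: Encoder[T] = enc
}

import spark.implicits._
val typedSumOfLong = udaf(new TypedSum[Long])
spark.range(10).select(typedSumOfLong($"id")).show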