Apache Spark SQL UDAF over window显示重复输入的奇怪行为

ab8*_*853 13 apache-spark apache-spark-sql

我发现在Apache Spark SQL(版本2.2.0)中,当在窗口规范上使用的用户定义的聚合函数(UDAF)提供了多行相同的输入时,UDAF(看似)不会调用evaluate方法正确.

我已经能够在Java和Scala中,本地和群集上重现这种行为.下面的代码显示了一个示例,如果行在前一行的1秒内,则标记为false.

class ExampleUDAF(val timeLimit: Long) extends UserDefinedAggregateFunction {
  def deterministic: Boolean = true
  def inputSchema: StructType = StructType(Array(StructField("unix_time", LongType)))
  def dataType: DataType = BooleanType

  def bufferSchema = StructType(Array(
    StructField("previousKeepTime", LongType),
    StructField("keepRow", BooleanType)
  ))

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = 0L
    buffer(1) = false
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {    
    if (buffer(0) == 0L) {
      buffer(0) = input.getLong(0)
      buffer(1) = true
    } else {
      val timeDiff = input.getLong(0) - buffer.getLong(0)

      if (timeDiff < timeLimit) {
        buffer(1) = false
      } else {
        buffer(0) = input.getLong(0)
        buffer(1) = true
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {} // Not implemented
  def evaluate(buffer: Row): Boolean = buffer.getBoolean(1)
 }

val timeLimit = 1000 // 1 second
val udaf = new ExampleUDAF(timeLimit)

val window = Window
  .orderBy(column("unix_time"))
  .partitionBy(column("category"))

val df = spark.createDataFrame(Arrays.asList(
    Row(1510000001000L, "a", true), 
    Row(1510000001000L, "a", false), 
    Row(1510000001000L, "a", false),
    Row(1510000001000L, "a", false),
    Row(1510000700000L, "a", true),
    Row(1510000700000L, "a", false)
  ), new StructType().add("unix_time", LongType).add("category", StringType).add("expected_result", BooleanType))

df.withColumn("actual_result", udaf(column("unix_time")).over(window)).show
Run Code Online (Sandbox Code Playgroud)

下面是运行上面代码的输出.由于actual_result没有先前的数据,因此第一行的值应为true.当unix_time输入被修改为在每条记录之间有1毫秒时,UDAF按预期工作.

在UDAF方法中添加print语句显示最后evaluate只调用一次,并且该方法中缓冲区已正确更新为true update,但这不是UDAF完成后返回的内容.

+-------------+--------+---------------+-------------+
|    unix_time|category|expected_result|actual_result|
+-------------+--------+---------------+-------------+
|1510000001000|       a|           true|        false|  // Should true as first element
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000700000|       a|           true|        false|  // Should be true as more than 1000 milliseconds between self and previous
|1510000700000|       a|          false|        false|
+-------------+--------+---------------+-------------+
Run Code Online (Sandbox Code Playgroud)

当在窗口规范上使用时,我正确理解Spark的UDAF行为?如果没有,任何人都可以提供这方面的任何见解.如果我对Windows上的UDAF行为的理解是正确的,那么这可能是Spark中的一个错误吗?谢谢.

ast*_*asz 8

您的UDAF的一个问题是它没有指定您要在哪些行上运行窗口rowsBetween().如果没有rowsBetween()规范,则对于每一行,窗口函数将占用当前一行(包括当前类别)之前和之后的所有行(参见下面的更新)行.因此,actual_result所有行基本上都会考虑到只有两个在你的例子最后一行DataFrame,用unix_time=1510000700000它有效地将返回false所有行.

有了这样的window声明:

Window.partitionBy(col("category")).orderBy(col("unix_time")).rowsBetween(-1L, 0L)
Run Code Online (Sandbox Code Playgroud)

您始终只查看上一行和当前行.前一行先行.这会创建正确的输出.但由于具有相同行的排序unix_time不是唯一的,因此无法预测哪一true行在具有相同的行中具有值unix_time.

结果可能如下所示:

+-------------+--------+---------------+-------------+
|    unix_time|category|expected_result|actual_result|
+-------------+--------+---------------+-------------+
|1510000001000|       a|          false|         true|
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000001000|       a|           true|        false|
|1510000700000|       a|           true|         true|
|1510000700000|       a|          false|        false|
+-------------+--------+---------------+-------------+
Run Code Online (Sandbox Code Playgroud)

更新

在进一步调查之后,似乎在orderBy提供列时它将获取当前行+当前行之前的所有元素.不是像我之前说过的所有分区元素.此外,如果orderBy列包含重复值,则每个重复行的窗口将包含所有重复值.你可以通过这样做清楚地看到它:

val wA = Window.partitionBy(col("category")).orderBy(col("unix_time"))
val wB = Window.partitionBy(col("category"))
val wC = Window.partitionBy(col("category")).orderBy(col("unix_time")).rowsBetween(-1L, 0L)

df.withColumn("countRows", count(col("unix_time")).over(wA)).show()
df.withColumn("countRows", count(col("unix_time")).over(wB)).show()
df.withColumn("countRows", count(col("unix_time")).over(wC)).show()
Run Code Online (Sandbox Code Playgroud)

这将计算每个窗口中元素的数量.

  • 窗口wA每个1510000001000行有4个元素,每个1510000700000有6个元素.
  • 因为wBorderBy每个分区的窗口中都没有包含所有行时,所以所有窗口都将有6个元素.
  • 最后一个wC指定了行的选择,因此不会产生歧义,为哪个窗口选择哪些行.所有后续行的第一行只有1个元素,窗口中只有2个元素.这产生了正确的结果.

我今天也学到了新东西:)