Spark collect_list 并限制结果列表

pir*_*x22 5 scala limit dataframe apache-spark

我有以下格式的数据框:

name          merged
key1    (internalKey1, value1)
key1    (internalKey2, value2)
...
key2    (internalKey3, value3)
...
Run Code Online (Sandbox Code Playgroud)

我想要做的是按 对数据框进行分组name,收集列表并限制列表的大小。

这是我分组name和收集列表的方式:

val res = df.groupBy("name")
            .agg(collect_list(col("merged")).as("final"))
Run Code Online (Sandbox Code Playgroud)

结果数据框类似于:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
 key2   [(internalKey3, value3),...]
Run Code Online (Sandbox Code Playgroud)

我想要做的是限制每个键生成的列表的大小。我尝试了多种方法来做到这一点,但都没有成功。我已经看到一些建议使用 3rd 方解决方案的帖子,但我想避免这种情况。有办法吗?

use*_*563 6

因此,虽然 UDF 可以满足您的需要,但如果您正在寻找一种性能更高且对内存敏感的方法,则可以编写 UDAF。不幸的是,UDAF API 实际上不如 Spark 附带的聚合函数那么可扩展。但是,您可以使用其内部 API 来构建内部函数来执行您需要的操作。

下面是一个实现,collect_list_limit主要是 Spark 内部 AggregateFunction 的复制CollectList。我只想扩展它,但它是一个案例类。实际上,所需要的只是重写更新和合并方法以尊重传入的限制:

case class CollectListLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
    if( buffer.size >= limit ) buffer
    else if( other.size >= limit ) other
    else ( buffer ++= other ).take( limit )
  }

  override def prettyName: String = "collect_list_limit"
}
Run Code Online (Sandbox Code Playgroud)

要实际注册它,我们可以通过 Spark 的内部来完成FunctionRegistry,它接受名称和构建器,它实际上是一个CollectListLimit使用提供的表达式创建的函数:

val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )
Run Code Online (Sandbox Code Playgroud)

编辑:

事实证明,仅当您尚未创建 SparkContext 时才将其添加到内置中才有效,因为它会在启动时创建不可变的克隆。如果您有现有的上下文,那么这应该可以通过反射添加它:

val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
Run Code Online (Sandbox Code Playgroud)


Leo*_*o C 5

您可以创建一个函数来限制聚合 ArrayType 列的大小,如下所示:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column

case class KV(k: String, v: String)

val df = Seq(
  ("key1", KV("internalKey1", "value1")),
  ("key1", KV("internalKey2", "value2")),
  ("key2", KV("internalKey3", "value3")),
  ("key2", KV("internalKey4", "value4")),
  ("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")

def limitSize(n: Int, arrCol: Column): Column =
  array( (0 until n).map( arrCol.getItem ): _* )

df.
  groupBy("name").agg( collect_list(col("merged")).as("final") ).
  select( $"name", limitSize(2, $"final").as("final2") ).
  show(false)
// +----+----------------------------------------------+
// |name|final2                                        |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
Run Code Online (Sandbox Code Playgroud)