pir*_*x22 5 scala limit dataframe apache-spark
我有以下格式的数据框:
name merged
key1 (internalKey1, value1)
key1 (internalKey2, value2)
...
key2 (internalKey3, value3)
...
Run Code Online (Sandbox Code Playgroud)
我想要做的是按 对数据框进行分组name,收集列表并限制列表的大小。
这是我分组name和收集列表的方式:
val res = df.groupBy("name")
.agg(collect_list(col("merged")).as("final"))
Run Code Online (Sandbox Code Playgroud)
结果数据框类似于:
key1 [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list
key2 [(internalKey3, value3),...]
Run Code Online (Sandbox Code Playgroud)
我想要做的是限制每个键生成的列表的大小。我尝试了多种方法来做到这一点,但都没有成功。我已经看到一些建议使用 3rd 方解决方案的帖子,但我想避免这种情况。有办法吗?
因此,虽然 UDF 可以满足您的需要,但如果您正在寻找一种性能更高且对内存敏感的方法,则可以编写 UDAF。不幸的是,UDAF API 实际上不如 Spark 附带的聚合函数那么可扩展。但是,您可以使用其内部 API 来构建内部函数来执行您需要的操作。
下面是一个实现,collect_list_limit主要是 Spark 内部 AggregateFunction 的复制CollectList。我只想扩展它,但它是一个案例类。实际上,所需要的只是重写更新和合并方法以尊重传入的限制:
case class CollectListLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
val limit = limitExp.eval( null ).asInstanceOf[Int]
def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}
override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
if( buffer.size >= limit ) buffer
else if( other.size >= limit ) other
else ( buffer ++= other ).take( limit )
}
override def prettyName: String = "collect_list_limit"
}
Run Code Online (Sandbox Code Playgroud)
要实际注册它,我们可以通过 Spark 的内部来完成FunctionRegistry,它接受名称和构建器,它实际上是一个CollectListLimit使用提供的表达式创建的函数:
val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )
Run Code Online (Sandbox Code Playgroud)
编辑:
事实证明,仅当您尚未创建 SparkContext 时才将其添加到内置中才有效,因为它会在启动时创建不可变的克隆。如果您有现有的上下文,那么这应该可以通过反射添加它:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
Run Code Online (Sandbox Code Playgroud)
您可以创建一个函数来限制聚合 ArrayType 列的大小,如下所示:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
case class KV(k: String, v: String)
val df = Seq(
("key1", KV("internalKey1", "value1")),
("key1", KV("internalKey2", "value2")),
("key2", KV("internalKey3", "value3")),
("key2", KV("internalKey4", "value4")),
("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")
def limitSize(n: Int, arrCol: Column): Column =
array( (0 until n).map( arrCol.getItem ): _* )
df.
groupBy("name").agg( collect_list(col("merged")).as("final") ).
select( $"name", limitSize(2, $"final").as("final2") ).
show(false)
// +----+----------------------------------------------+
// |name|final2 |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
10364 次 |
| 最近记录: |