Spark aggregation over multiple columns within a partition, without a shuffle

147*_*580 6 shuffle aggregation partition apache-spark

I'm trying to aggregate a dataframe over multiple columns. I know that everything needed for the aggregation is within each partition; that is, no shuffle should be required, because all of the data for any one aggregate is local to its partition.

As an example, if I have something like

    // force two partitions
    val sales = sc.parallelize(List(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5))).repartition(2)
    // sum(amt), min(amt), max(amt), sum(units) per (store, prod) key;
    // reduceByKey is what introduces the shuffle seen below
    val tdf = sales.map { case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }.
      reduceByKey((x, y) => (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4))
    println(tdf.toDebugString)

The result I get is something like this:

(2) ShuffledRDD[12] at reduceByKey at Test.scala:59 []
 +-(2) MapPartitionsRDD[11] at map at Test.scala:58 []
    |  MapPartitionsRDD[10] at repartition at Test.scala:57 []
    |  CoalescedRDD[9] at repartition at Test.scala:57 []
    |  ShuffledRDD[8] at repartition at Test.scala:57 []
    +-(1) MapPartitionsRDD[7] at repartition at Test.scala:57 []
       |  ParallelCollectionRDD[6] at parallelize at Test.scala:51 []

You can see the MapPartitionsRDD, which is fine. But there is also a ShuffledRDD, which I want to prevent, because I want the summarization done per partition, grouped by the column values within each partition.

Zero323's suggestion is tantalizingly close, but I need the "group by columns" functionality.

Referring to the example above, I'm looking for the result that would be produced by

select store, prod, sum(amt), avg(units) from sales group by partition_id, store, prod

(I don't really need the partition id; it's only there to illustrate that I want per-partition results.)
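To pin the semantics down, here is roughly the same query sketched in the DataFrame API, building on the sales RDD above. spark_partition_id() tags each row with its physical partition; note that a groupBy like this still shuffles, which is exactly what I want to avoid.

    import spark.implicits._
    import org.apache.spark.sql.functions._

    // Illustration only: reproduce the per-partition grouping of the SQL above.
    // This groupBy still shuffles; it just shows the result I am after.
    val salesDF = sales.toDF("store", "prod", "amt", "units")
    salesDF
      .withColumn("partition_id", spark_partition_id())
      .groupBy("partition_id", "store", "prod")
      .agg(sum("amt").as("total_amt"), avg("units").as("avg_units"))
      .show()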

I've looked at many examples, but every debug string I've produced contains a shuffle, and I'd really like to get rid of it. I think what I'm essentially looking for is a groupByKeysWithinPartitions function.

Tra*_*ian 4

The only way to achieve this is to use mapPartitions, with custom code that groups and computes your values while iterating over the partition. Since, as you mention, the data is already sorted by the grouping keys (store, prod), we can compute your aggregates efficiently in a pipelined fashion:

(1) Define the helper classes:

:paste

case class MyRec(store: String, prod: String, amt: Double, units: Int)

case class MyResult(store: String, prod: String, total_amt: Double, min_amt: Double, max_amt: Double, total_units: Int)

object MyResult {
  // seed a result from a single record: amt initializes total, min and max
  def apply(rec: MyRec): MyResult =
    new MyResult(rec.store, rec.prod, rec.amt, rec.amt, rec.amt, rec.units)

  // fold one more record into a running result
  def aggregate(result: MyResult, rec: MyRec): MyResult =
    new MyResult(result.store,
      result.prod,
      result.total_amt + rec.amt,
      math.min(result.min_amt, rec.amt),
      math.max(result.max_amt, rec.amt),
      result.total_units + rec.units
    )
}
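
A quick sanity check of the helpers, with the expected values worked out by hand from the definitions above:

// a single record seeds total, min and max with its own amt
val first = MyResult(MyRec("West", "Apple", 2.0, 10))
// first: MyResult(West,Apple,2.0,2.0,2.0,10)

// folding in a second record updates every aggregate
val merged = MyResult.aggregate(first, MyRec("West", "Apple", 3.0, 15))
// merged: MyResult(West,Apple,5.0,2.0,3.0,25)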

(2) Define the pipelined aggregator:

:paste

def pipelinedAggregator(iter: Iterator[MyRec]): Iterator[Seq[MyResult]] = {

  var prev: MyResult = null
  var res: Seq[MyResult] = Nil

  for (crt <- iter) yield {
    res = Nil // reset the emission buffer so each closed group is emitted exactly once
    if (prev == null) {
      // first record of the partition seeds the running aggregate
      prev = MyResult(crt)
    }
    else if (prev.prod != crt.prod || prev.store != crt.store) {
      // grouping key changed: the previous group is complete, emit it
      res = Seq(prev)
      prev = MyResult(crt)
    }
    else {
      prev = MyResult.aggregate(prev, crt)
    }

    if (!iter.hasNext) {
      // end of the partition: flush the in-flight aggregate as well
      res = res ++ Seq(prev)
    }

    res
  }
}
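
One caveat on the design: the pipelined aggregator depends on rows with the same (store, prod) key being adjacent within their partition, which holds for the sample data. If that adjacency cannot be guaranteed, a map-based variant (a sketch of mine, not part of the original answer) stays shuffle-free at the cost of keeping one running result per distinct key in memory:

import scala.collection.mutable

// Order-independent variant: rows for a key need not be adjacent.
// Memory grows with the number of distinct keys per partition
// instead of staying constant as in the pipelined version.
def mapAggregator(iter: Iterator[MyRec]): Iterator[MyResult] = {
  val acc = mutable.LinkedHashMap.empty[(String, String), MyResult]
  for (rec <- iter) {
    val key = (rec.store, rec.prod)
    acc(key) = acc.get(key)
      .map(MyResult.aggregate(_, rec))
      .getOrElse(MyResult(rec))
  }
  acc.valuesIterator
}

// usage is the same shuffle-free one-liner:
// sales.mapPartitions(mapAggregator).show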

(3) Run the aggregation:

:paste

val sales = sc.parallelize(
  List(MyRec("West", "Apple", 2.0, 10),
    MyRec("West", "Apple", 3.0, 15),
    MyRec("West", "Orange", 5.0, 15),
    MyRec("South", "Orange", 3.0, 9),
    MyRec("South", "Orange", 6.0, 18),
    MyRec("East", "Milk", 5.0, 5),
    MyRec("West", "Apple", 7.0, 11)), 2).toDS

// peek at the contents of each partition
sales.mapPartitions(iter => Iterator(iter.toList)).show(false)

val result = sales
  .mapPartitions(recIter => pipelinedAggregator(recIter)) // aggregate each partition independently, no shuffle
  .flatMap(identity) // flatten the Seq[MyResult] emitted for each input row

result.show
result.explain

Output:

    +-------------------------------------------------------------------------------------+
    |value                                                                                |
    +-------------------------------------------------------------------------------------+
    |[[West,Apple,2.0,10], [West,Apple,3.0,15], [West,Orange,5.0,15]]                     |
    |[[South,Orange,3.0,9], [South,Orange,6.0,18], [East,Milk,5.0,5], [West,Apple,7.0,11]]|
    +-------------------------------------------------------------------------------------+

    +-----+------+---------+-------+-------+-----------+
    |store|  prod|total_amt|min_amt|max_amt|total_units|
    +-----+------+---------+-------+-------+-----------+
    | West| Apple|      5.0|    2.0|    3.0|         25|
    | West|Orange|      5.0|    5.0|    5.0|         15|
    |South|Orange|      9.0|    3.0|    6.0|         27|
    | East|  Milk|      5.0|    5.0|    5.0|          5|
    | West| Apple|      7.0|    7.0|    7.0|         11|
    +-----+------+---------+-------+-------+-----------+

    == Physical Plan ==
    *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).store, true) AS store#31, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).prod, true) AS prod#32, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_amt AS total_amt#33, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).min_amt AS min_amt#34, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).max_amt AS max_amt#35, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_units AS total_units#36]
    +- MapPartitions <function1>, obj#30: $line14.$read$$iw$$iw$MyResult
       +- MapPartitions <function1>, obj#20: scala.collection.Seq
          +- Scan ExternalRDDScan[obj#4]
    sales: org.apache.spark.sql.Dataset[MyRec] = [store: string, prod: string ... 2 more fields]
    result: org.apache.spark.sql.Dataset[MyResult] = [store: string, prod: string ... 4 more fields]    
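Note that the physical plan consists only of MapPartitions steps over the external RDD scan; there is no Exchange operator anywhere, which confirms that nothing was shuffled.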