147*_*580 6 shuffle aggregation partition apache-spark
我正在尝试在多个列上聚合数据框。我知道聚合所需的所有内容都在分区内 - 也就是说,不需要洗牌,因为聚合的所有数据都是分区本地的。
举个例子,如果我有类似的东西
val sales=sc.parallelize(List(
("West", "Apple", 2.0, 10),
("West", "Apple", 3.0, 15),
("West", "Orange", 5.0, 15),
("South", "Orange", 3.0, 9),
("South", "Orange", 6.0, 18),
("East", "Milk", 5.0, 5))).repartition(2)
val tdf = sales.map{ case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }.
reduceByKey((x, y) => (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4))
println(tdf.toDebugString)
Run Code Online (Sandbox Code Playgroud)
我得到的结果是这样的
(2) ShuffledRDD[12] at reduceByKey at Test.scala:59 []
+-(2) MapPartitionsRDD[11] at map at Test.scala:58 []
| MapPartitionsRDD[10] at repartition at Test.scala:57 []
| CoalescedRDD[9] at repartition at Test.scala:57 []
| ShuffledRDD[8] at repartition at Test.scala:57 []
+-(1) MapPartitionsRDD[7] at repartition at Test.scala:57 []
| ParallelCollectionRDD[6] at parallelize at Test.scala:51 []
Run Code Online (Sandbox Code Playgroud)
你可以看到MapPartitionsRDD,这很好。但是还有 ShuffleRDD,我想阻止它,因为我想要按分区进行汇总,并按分区内的列值进行分组。
参考上面的示例,我正在寻找将产生的结果
select store, prod, sum(amt), avg(units) from sales group by partition_id, store, prod
Run Code Online (Sandbox Code Playgroud)
(我真的不需要分区 ID - 这只是为了说明我想要每个分区的结果)
我看过很多 例子,但我生成的每个调试字符串都有随机播放。 我真的希望摆脱洗牌。我想我本质上是在寻找 groupByKeysWithinPartitions 函数。
实现这一目标的唯一方法是使用 mapPartitions 并使用自定义代码在迭代分区时对值进行分组和计算。正如您提到的,数据已经按分组键(store、prod)排序,我们可以以管道方式有效地计算您的聚合:
(1) 定义辅助类:
:paste
case class MyRec(store: String, prod: String, amt: Double, units: Int)
case class MyResult(store: String, prod: String, total_amt: Double, min_amt: Double, max_amt: Double, total_units: Int)
object MyResult {
def apply(rec: MyRec): MyResult = new MyResult(rec.store, rec.prod, rec.amt, rec.amt, rec.amt, rec.units)
def aggregate(result: MyResult, rec: MyRec) = {
new MyResult(result.store,
result.prod,
result.total_amt + rec.amt,
math.min(result.min_amt, rec.amt),
math.max(result.max_amt, rec.amt),
result.total_units + rec.units
)
}
}
Run Code Online (Sandbox Code Playgroud)
(2) 定义流水线聚合器:
:paste
def pipelinedAggregator(iter: Iterator[MyRec]): Iterator[Seq[MyResult]] = {
var prev: MyResult = null
var res: Seq[MyResult] = Nil
for (crt <- iter) yield {
if (prev == null) {
prev = MyResult(crt)
}
else if (prev.prod != crt.prod || prev.store != crt.store) {
res = Seq(prev)
prev = MyResult(crt)
}
else {
prev = MyResult.aggregate(prev, crt)
}
if (!iter.hasNext) {
res = res ++ Seq(prev)
}
res
}
Run Code Online (Sandbox Code Playgroud)
}
(3)运行聚合:
:paste
val sales = sc.parallelize(
List(MyRec("West", "Apple", 2.0, 10),
MyRec("West", "Apple", 3.0, 15),
MyRec("West", "Orange", 5.0, 15),
MyRec("South", "Orange", 3.0, 9),
MyRec("South", "Orange", 6.0, 18),
MyRec("East", "Milk", 5.0, 5),
MyRec("West", "Apple", 7.0, 11)), 2).toDS
sales.mapPartitions(iter => Iterator(iter.toList)).show(false)
val result = sales
.mapPartitions(recIter => pipelinedAggregator(recIter))
.flatMap(identity)
result.show
result.explain
Run Code Online (Sandbox Code Playgroud)
输出:
+-------------------------------------------------------------------------------------+
|value |
+-------------------------------------------------------------------------------------+
|[[West,Apple,2.0,10], [West,Apple,3.0,15], [West,Orange,5.0,15]] |
|[[South,Orange,3.0,9], [South,Orange,6.0,18], [East,Milk,5.0,5], [West,Apple,7.0,11]]|
+-------------------------------------------------------------------------------------+
+-----+------+---------+-------+-------+-----------+
|store| prod|total_amt|min_amt|max_amt|total_units|
+-----+------+---------+-------+-------+-----------+
| West| Apple| 5.0| 2.0| 3.0| 25|
| West|Orange| 5.0| 5.0| 5.0| 15|
|South|Orange| 9.0| 3.0| 6.0| 27|
| East| Milk| 5.0| 5.0| 5.0| 5|
| West| Apple| 7.0| 7.0| 7.0| 11|
+-----+------+---------+-------+-------+-----------+
== Physical Plan ==
*SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).store, true) AS store#31, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).prod, true) AS prod#32, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_amt AS total_amt#33, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).min_amt AS min_amt#34, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).max_amt AS max_amt#35, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_units AS total_units#36]
+- MapPartitions <function1>, obj#30: $line14.$read$$iw$$iw$MyResult
+- MapPartitions <function1>, obj#20: scala.collection.Seq
+- Scan ExternalRDDScan[obj#4]
sales: org.apache.spark.sql.Dataset[MyRec] = [store: string, prod: string ... 2 more fields]
result: org.apache.spark.sql.Dataset[MyResult] = [store: string, prod: string ... 4 more fields]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
6710 次 |
| 最近记录: |