Cumulative product in Spark?

Mar*_*rti 0 scala aggregation apache-spark apache-spark-sql

I am trying to implement a cumulative product in Spark Scala, but I really don't know how to do it. I have the following dataframe:

Input data:
+--+--+--------+----+
|A |B | date   | val|
+--+--+--------+----+
|rr|gg|20171103| 2  |
|hh|jj|20171103| 3  |
|rr|gg|20171104| 4  |
|hh|jj|20171104| 5  |
|rr|gg|20171105| 6  |
|hh|jj|20171105| 7  |
+--+--+--------+----+

I would like the following output:

Output data:
+--+--+--------+-----+
|A |B | date   | val |
+--+--+--------+-----+
|rr|gg|20171105| 48  | // 2 * 4 * 6
|hh|jj|20171105| 105 | // 3 * 5 * 7
+--+--+--------+-----+

If you have any idea how to do this, it would be very helpful :)

Thanks a lot.

zer*_*323 8

As long as the numbers are strictly positive, as in your example (a 0, if present, can be handled as well using coalesce), the simplest solution is to compute the sum of logarithms and take the exponential:

import org.apache.spark.sql.functions.{exp, log, max, round, sum}
import spark.implicits._  // already in scope in spark-shell; needed for toDF and $ elsewhere

val df = Seq(
  ("rr", "gg", "20171103", 2), ("hh", "jj", "20171103", 3), 
  ("rr", "gg", "20171104", 4), ("hh", "jj", "20171104", 5), 
  ("rr", "gg", "20171105", 6), ("hh", "jj", "20171105", 7)
).toDF("A", "B", "date", "val")

val result = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"), 
    // product as the exponential of the sum of logs: exp(log(a) + log(b) + ...) = a * b * ...
    exp(sum(log($"val"))).as("val"))

Since this uses floating-point arithmetic, the result will not be exact:

result.show
+---+---+--------+------------------+
|  A|  B|    date|               val|
+---+---+--------+------------------+
| hh| jj|20171105|104.99999999999997|
| rr| gg|20171105|47.999999999999986|
+---+---+--------+------------------+

but after rounding it should be good enough for most applications:

result.withColumn("val", round($"val")).show
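
If val can also be 0, the zero case can be handled separately. Below is a minimal sketch of one way to do it (it uses a when/min guard rather than the coalesce variant mentioned above, and assumes the values are non-negative):

import org.apache.spark.sql.functions.{exp, lit, log, max, min, sum, when}

val resultWithZeros = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"),
    // any 0 in the group makes the whole product 0; otherwise fall back to exp(sum(log))
    when(min($"val") === 0, lit(0.0))
      .otherwise(exp(sum(log($"val")))).as("val"))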

If this is not enough, you can define a UserDefinedAggregateFunction or an Aggregator (see How to define and use a user-defined aggregate function in Spark SQL?), or use the typed Dataset API with reduceGroups:

import scala.math.Ordering

case class Record(A: String, B: String, date: String, value: Long)

df.withColumnRenamed("val", "value").as[Record]
  .groupByKey(x => (x.A, x.B))
  // keep the latest date and multiply the values pairwise within each group
  .reduceGroups((x, y) => x.copy(
    date = Ordering[String].max(x.date, y.date),
    value = x.value * y.value))
  .toDF("key", "value")
  .select($"value.*")  // unpack the reduced Record struct back into columns
  .show
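
As a further alternative, on Spark 3.0 and later an Aggregator can also be plugged into the regular agg API via functions.udaf. A minimal sketch (the LongProduct name is just an example), which keeps the product exact by using integer arithmetic:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{max, udaf}
import org.apache.spark.sql.{Encoder, Encoders}

// exact product of the Long values in a group
object LongProduct extends Aggregator[Long, Long, Long] {
  def zero: Long = 1L
  def reduce(acc: Long, x: Long): Long = acc * x
  def merge(a: Long, b: Long): Long = a * b
  def finish(acc: Long): Long = acc
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

val product = udaf(LongProduct)

df.groupBy("A", "B")
  .agg(max($"date").as("date"), product($"val".cast("long")).as("val"))
  .show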