Spark SQL:在数组值上使用collect_set?

sha*_*dzy 5 java apache-spark apache-spark-sql

我有一个聚合的 DataFrame,其中有一列使用collect_set. 我现在需要再次聚合这个 DataFrame,并再次应用于collect_set该列的值。问题是我需要应用collect_Set集合的值 - 到目前为止,我看到的唯一方法是分解聚合的 DataFrame。有没有更好的办法?

例子:

初始数据帧:

country   | continent   | attributes
-------------------------------------
Canada    | America     | A
Belgium   | Europe      | Z
USA       | America     | A
Canada    | America     | B
France    | Europe      | Y
France    | Europe      | X
Run Code Online (Sandbox Code Playgroud)

聚合数据帧(我作为输入接收的那个) - 聚合country

country   | continent   | attributes
-------------------------------------
Canada    | America     | A, B
Belgium   | Europe      | Z
USA       | America     | A
France    | Europe      | Y, X
Run Code Online (Sandbox Code Playgroud)

我想要的输出 - 聚合continent

continent   | attributes
-------------------------------------
America     | A, B
Europe      | X, Y, Z
Run Code Online (Sandbox Code Playgroud)

104*_*ica 5

由于此时您只能拥有少量行,因此您只需按原样收集属性并将结果展平(Spark >= 2.4)

import org.apache.spark.sql.functions.{collect_set, flatten, array_distinct}

val byState = Seq(
  ("Canada", "America", Seq("A", "B")),
  ("Belgium", "Europe", Seq("Z")),
  ("USA", "America", Seq("A")),
  ("France", "Europe", Seq("Y", "X"))
).toDF("country", "continent", "attributes")

byState
  .groupBy("continent")
  .agg(array_distinct(flatten(collect_set($"attributes"))) as "attributes")
  .show
Run Code Online (Sandbox Code Playgroud)
import org.apache.spark.sql.functions.{collect_set, flatten, array_distinct}

val byState = Seq(
  ("Canada", "America", Seq("A", "B")),
  ("Belgium", "Europe", Seq("Z")),
  ("USA", "America", Seq("A")),
  ("France", "Europe", Seq("Y", "X"))
).toDF("country", "continent", "attributes")

byState
  .groupBy("continent")
  .agg(array_distinct(flatten(collect_set($"attributes"))) as "attributes")
  .show
Run Code Online (Sandbox Code Playgroud)

在一般情况下,事情更难处理,并且在许多情况下,如果您期望大列表,每个组有很多重复和很多值,最佳解决方案*是从头开始重新计算结果,即

input.groupBy($"continent").agg(collect_set($"attributes") as "attributes")
Run Code Online (Sandbox Code Playgroud)

一种可能的替代方法是使用 Aggregator

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.{Encoder, Encoders}
import scala.collection.mutable.{Set => MSet}


class MergeSets[T, U](f: T => Seq[U])(implicit enc: Encoder[Seq[U]]) extends 
     Aggregator[T, MSet[U], Seq[U]] with Serializable {

  def zero = MSet.empty[U]

  def reduce(acc: MSet[U], x: T) = {
    for { v <- f(x) } acc.add(v)
    acc
  }

  def merge(acc1: MSet[U], acc2: MSet[U]) = {
    acc1 ++= acc2
  }

  def finish(acc: MSet[U]) = acc.toSeq
  def bufferEncoder: Encoder[MSet[U]] = Encoders.kryo[MSet[U]]
  def outputEncoder: Encoder[Seq[U]] = enc

}
Run Code Online (Sandbox Code Playgroud)

并按如下方式应用它

case class CountryAggregate(
  country: String, continent: String, attributes: Seq[String])

byState
  .as[CountryAggregate]
  .groupByKey(_.continent)
  .agg(new MergeSets[CountryAggregate, String](_.attributes).toColumn)
  .toDF("continent", "attributes")
  .show
Run Code Online (Sandbox Code Playgroud)
+---------+----------+
|continent|attributes|
+---------+----------+
|   Europe| [Y, X, Z]|
|  America|    [A, B]|
+---------+----------+
Run Code Online (Sandbox Code Playgroud)

但这显然不是一个对 Java 友好的选项。

另请参阅如何在 groupBy 之后将值聚合到集合中?(类似,但没有唯一性约束)。


* 那是因为explode可能非常昂贵,尤其是在较旧的 Spark 版本中,与访问 SQL 集合的外部表示相同。