Spark/Scala: repeatedly calling withColumn() with the same function on multiple columns

Dam*_*ips 15 scala user-defined-functions dataframe apache-spark apache-spark-sql

I currently have code in which I repeatedly apply the same procedure to multiple DataFrame columns via multiple chains of .withColumn, and I want to create a function to streamline the procedure. In my case, I am finding cumulative sums over columns aggregated by key:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val newDF = oldDF
  .withColumn("cumA", sum("A").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumB", sum("B").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumC", sum("C").over(Window.partitionBy("ID").orderBy("time")))
  //.withColumn(...)

What I would like is something of the form:

def createCumulativeColums(cols: Array[String], df: DataFrame): DataFrame = {
  // Implement the above cumulative sums, partitioning, and ordering
}

Or better yet:

def withColumns(cols: Array[String], df: DataFrame, f: String => Column): DataFrame = {
  // Implement a udf/arbitrary function on all the specified columns
}

use*_*411 27

You can use select with varargs, including *:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum
import spark.implicits._

df.select($"*" +: Seq("A", "B", "C").map(c => 
  sum(c).over(Window.partitionBy("ID").orderBy("time")).alias(s"cum$c")
): _*)

This:

  • maps the column names to window expressions with Seq("A", ...).map(...);
  • prepends all pre-existing columns with $"*" +: ...;
  • unpacks the combined sequence with ... : _* (a self-contained sketch follows below).
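A minimal end-to-end sketch of the one-liner above, with made-up sample data (the column names ID/time/A/B/C come from the question; the data and session setup are purely illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val spark = SparkSession.builder.master("local[*]").getOrCreate()
import spark.implicits._

// Toy data: two keys, cumulative sums ordered by time within each key.
val df = Seq(
  ("x", 1, 10, 100, 1000),
  ("x", 2, 20, 200, 2000),
  ("y", 1, 30, 300, 3000)
).toDF("ID", "time", "A", "B", "C")

val w = Window.partitionBy("ID").orderBy("time")
df.select($"*" +: Seq("A", "B", "C").map(c => sum(c).over(w).alias(s"cum$c")): _*).show()
// Produces the original columns plus cumA, cumB, cumC, the same result as
// the chained withColumn version in the question.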

and can be generalized to:

import org.apache.spark.sql.{Column, DataFrame}
import spark.implicits._  // needed for the $"*" interpolator below

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 */
def withColumns(cols: Seq[String], df: DataFrame, f: String => Column): DataFrame =
  df.select($"*" +: cols.map(c => f(c)): _*)
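For example, reproducing the question's cumulative sums with this helper (a usage sketch; df is the input DataFrame from the question and w is just a local name for the shared window spec):

val w = Window.partitionBy("ID").orderBy("time")
val newDF = withColumns(Seq("A", "B", "C"), df,
  c => sum(c).over(w).alias(s"cum$c"))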

If you find the withColumn syntax more readable, you can use foldLeft:

Seq("A", "B", "C").foldLeft(df)((df, c) =>
  df.withColumn(s"cum$c",  sum(c).over(Window.partitionBy("ID").orderBy("time")))
)

which can be generalized, for example, to:

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 * @param name a function mapping from input to output name.
 */
def withColumns(cols: Seq[String], df: DataFrame,
    f: String => Column, name: String => String = identity): DataFrame =
  cols.foldLeft(df)((df, c) => df.withColumn(name(c), f(c)))
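Usage mirrors the select-based helper, with the output-name mapping passed separately (again a sketch under the question's column names):

val w = Window.partitionBy("ID").orderBy("time")
val newDF = withColumns(Seq("A", "B", "C"), df,
  c => sum(c).over(w),  // f: expression applied to each column
  c => s"cum$c")        // name: maps input name to output name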


小智 5

This question is a bit old, but I thought it would be useful to note (perhaps for others) that using a DataFrame as an accumulator to fold left over a list of columns and mapping over the columns into a single select have significantly different performance outcomes when the number of columns is not trivial (see here for the full explanation). Long story short: for a few columns foldLeft is fine, otherwise map is better, as sketched below.
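A minimal sketch of the two shapes being compared (the helper names are assumptions, not code from the linked explanation): the foldLeft version adds one projection to the logical plan per column, so Catalyst re-analyzes a growing plan at every step, while the map version builds all expressions up front and issues a single select.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val w = Window.partitionBy("ID").orderBy("time")
val cols = Seq("A", "B", "C")  // imagine hundreds of entries here

// One projection per column: fine for a few columns.
def viaFoldLeft(df: DataFrame): DataFrame =
  cols.foldLeft(df)((acc, c) => acc.withColumn(s"cum$c", sum(c).over(w)))

// One projection in total: scales better when cols is large.
def viaMap(df: DataFrame): DataFrame =
  df.select(df.columns.map(df(_)) ++ cols.map(c => sum(c).over(w).alias(s"cum$c")): _*)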