Returning multiple arrays from a user-defined aggregate function (UDAF) in Apache Spark SQL

ab8*_*853 9 java aggregate-functions user-defined-functions apache-spark apache-spark-sql

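I am trying to create a user-defined aggregate function (UDAF) in Java using Apache Spark SQL that returns multiple arrays on completion. I have searched online and cannot find any examples or suggestions on how to do this.

I am able to return a single array, but cannot figure out how to get the data into the correct format in the evaluate() method to return multiple arrays.

The UDAF itself works, since I can print out the arrays in the evaluate() method; I just can't figure out how to return those arrays to the calling code (shown below for reference).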

UserDefinedAggregateFunction customUDAF = new CustomUDAF();
DataFrame resultingDataFrame = dataFrame.groupBy().agg(customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col"))).as("processed_data");

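I have included the entire custom UDAF class below, but the key methods are dataType() and evaluate(), which are shown first.

Any help or advice would be greatly appreciated. Thank you.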

public class CustomUDAF extends UserDefinedAggregateFunction {

    @Override
    public DataType dataType() {
        // TODO: Is this the correct way to return 2 arrays?
        return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public Object evaluate(Row buffer) {
        // Data conversion
        List<Long> longList = new ArrayList<Long>(buffer.getList(0));
        List<Double> dataList = new ArrayList<Double>(buffer.getList(1));

        // Processing of data (omitted)

        // TODO: How to get data into format needed to return 2 arrays?
        return dataList;
    }

    @Override
    public StructType inputSchema() {
        return new StructType().add("long", DataTypes.LongType).add("data", DataTypes.DoubleType);
    }

    @Override
    public StructType bufferSchema() {
        return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, new ArrayList<Long>());
        buffer.update(1, new ArrayList<Double>());
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row row) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer.getList(0));
        longList.add(row.getLong(0));

        ArrayList<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        dataList.add(row.getDouble(1));

        buffer.update(0, longList);
        buffer.update(1, dataList);
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer1.getList(0));
        longList.addAll(buffer2.getList(0));

        ArrayList<Double> dataList = new ArrayList<Double>(buffer1.getList(1));
        dataList.addAll(buffer2.getList(1));

        buffer1.update(0, longList);
        buffer1.update(1, dataList);
    }

    @Override
    public boolean deterministic() {
        return true;
    }
}

Update: Based on zero323's answer, I was able to return the two arrays using:

return new Tuple2<>(longArray, dataArray);

Getting the data back out is a bit more difficult: it involves deconstructing the DataFrame into Java Lists and then building it back up into a DataFrame.
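For readers who want to see the shape of the aggregation without a Spark cluster, here is a minimal, Spark-free sketch in plain Java. It mirrors the initialize/update/merge/evaluate flow of the class above and shows evaluate() producing one value with two positional fields, matching the two-field struct declared in dataType(). Every name in it (TwoArrayAggregationSketch, Buffer, Result, demo) is hypothetical and is not Spark API. In real Spark code the value returned from evaluate() would presumably be a Row (for example one built with RowFactory.create(longList, dataList)) or a Scala tuple as in the update above; that part is an assumption and is not exercised here.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, Spark-free model of the UDAF above; none of these names are Spark API.
final class TwoArrayAggregationSketch {

    // Stand-in for the struct from dataType(): field 0 = longArray, field 1 = dataArray.
    static final class Result {
        final List<Long> longArray;
        final List<Double> dataArray;

        Result(List<Long> longArray, List<Double> dataArray) {
            this.longArray = longArray;
            this.dataArray = dataArray;
        }
    }

    // Stand-in for the aggregation buffer described by bufferSchema().
    static final class Buffer {
        List<Long> longs = new ArrayList<>();
        List<Double> doubles = new ArrayList<>();
    }

    // Like initialize(): start with two empty lists.
    static Buffer initialize() {
        return new Buffer();
    }

    // Like update(): copy each list, append the new value, store the copy back.
    static void update(Buffer buffer, long longValue, double doubleValue) {
        List<Long> longs = new ArrayList<>(buffer.longs);
        longs.add(longValue);
        List<Double> doubles = new ArrayList<>(buffer.doubles);
        doubles.add(doubleValue);
        buffer.longs = longs;
        buffer.doubles = doubles;
    }

    // Like merge(): append the second buffer's lists onto the first buffer's lists.
    static void merge(Buffer buffer1, Buffer buffer2) {
        List<Long> longs = new ArrayList<>(buffer1.longs);
        longs.addAll(buffer2.longs);
        List<Double> doubles = new ArrayList<>(buffer1.doubles);
        doubles.addAll(buffer2.doubles);
        buffer1.longs = longs;
        buffer1.doubles = doubles;
    }

    // Like evaluate(): return BOTH lists as one two-field value, not just one of them.
    static Result evaluate(Buffer buffer) {
        return new Result(new ArrayList<>(buffer.longs), new ArrayList<>(buffer.doubles));
    }

    // Simulates two partitions that are aggregated separately and then merged.
    static Result demo() {
        Buffer partition1 = initialize();
        update(partition1, 1L, 1.0);
        update(partition1, 2L, 2.0);

        Buffer partition2 = initialize();
        update(partition2, 3L, 3.0);

        merge(partition1, partition2);
        return evaluate(partition1);
    }

    public static void main(String[] args) {
        Result result = demo();
        // prints: [1, 2, 3] [1.0, 2.0, 3.0]
        System.out.println(result.longArray + " " + result.dataArray);
    }
}
```

The point of the sketch is only the return shape: the order of the fields in the returned value has to line up with the order of the fields declared for the result type.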

zer*_*323 7

As far as I can tell, returning a tuple should be enough. In Scala:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}

object DummyUDAF extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", StringType)
  def bufferSchema = new StructType()
    .add("buff", ArrayType(LongType))
    .add("buff2", ArrayType(DoubleType))
  def dataType = new StructType()
    .add("xs", ArrayType(LongType))
    .add("ys", ArrayType(DoubleType))
  def deterministic = true 
  def initialize(buffer: MutableAggregationBuffer) = {}
  def update(buffer: MutableAggregationBuffer, input: Row) = {}
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {}
  def evaluate(buffer: Row) = (Array(1L, 2L, 3L), Array(1.0, 2.0, 3.0))
}

val df =  sc.parallelize(Seq(("a", 1), ("b", 2))).toDF("k", "v")
df.select(DummyUDAF($"k")).show(1, false)

// +---------------------------------------------------+
// |(DummyUDAF$(k),mode=Complete,isDistinct=false)     |
// +---------------------------------------------------+
// |[WrappedArray(1, 2, 3),WrappedArray(1.0, 2.0, 3.0)]|
// +---------------------------------------------------+