在pyspark中应用用户定义的聚合函数的替代方法

Flo*_*ian 3 python user-defined-functions apache-spark pyspark

我正在尝试将用户定义的聚合函数应用于spark数据帧,以应用加法平滑,请参阅下面的代码:

import findspark
findspark.init()
import pyspark as ps
from pyspark.sql import SQLContext
from pyspark.sql.functions import col, col, collect_list, concat_ws, udf

try:
    sc
except NameError:
    sc = ps.SparkContext()
    sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame([['A', 1],
                            ['A',1],
                            ['A',0],
                            ['B',0],
                            ['B',0],
                            ['B',1]], schema=['name', 'val'])


def smooth_mean(x):
    return (sum(x)+5)/(len(x)+5)

smooth_mean_udf = udf(smooth_mean)

df.groupBy('name').agg(collect_list('val').alias('val'))\
.withColumn('val', smooth_mean_udf('val')).show()
Run Code Online (Sandbox Code Playgroud)

这样做是否有意义?根据我的理解,这不能很好地扩展,因为我使用的是udf.我也没有找到确切的工作collect_list,collect名称中的部分似乎表明数据被"收集"到边缘节点,但我假设数据被"收集"到各个节点?

提前感谢您的任何反馈.

hi-*_*zir 5

根据我的理解,这不会扩展

你的理解是正确的,这里最大的问题是collect_list是刚刚好老groupByKey.Python的udf影响要小得多,但使用它对于简单的算术运算没有意义.

只需使用标准聚合

from pyspark.sql.functions import sum as sum_, count

(df
    .groupBy("name")
    .agg(((sum_("val") + 5) / (count("val") + 5)).alias("val"))
    .show())

# +----+-----+
# |name|  val|
# +----+-----+
# |   B| 0.75|
# |   A|0.875|
# +----+-----+
Run Code Online (Sandbox Code Playgroud)