Pyspark:滚动窗口中的聚合模式(最常见)值

Car*_*hen 1 group-by apache-spark apache-spark-sql rolling-computation pyspark

我有一个如下所示的数据框。我想在每个组内进行分组device和排序start_time。然后,对于组中的每一行,从其前面 3 行(包括其自身)的窗口中获取最常出现的站点。

columns = ['device', 'start_time', 'station']
data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"), ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]


test_df = spark.createDataFrame(data).toDF(*columns)
rolling_w = Window.partitionBy('device').orderBy('start_time').rowsBetween(-2, 0)
Run Code Online (Sandbox Code Playgroud)

期望的输出:

columns = ['device', 'start_time', 'station']
data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"), ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]


test_df = spark.createDataFrame(data).toDF(*columns)
rolling_w = Window.partitionBy('device').orderBy('start_time').rowsBetween(-2, 0)
Run Code Online (Sandbox Code Playgroud)

由于 Pyspark 没有mode()函数,我知道如何获取静态中最常见的值,groupby如下所示但我不知道如何使其适应滚动窗口。

bla*_*hop 6

您可以使用collect_list函数使用定义的窗口获取最后 3 行的电台,然后为每个结果数组计算最常见的元素。

要获取数组中最常见的元素,您可以将其分解,然后按照您已经看到的链接帖子中的方式进行分组和计数,或者使用如下 UDF:

import pyspark.sql.functions as F

test_df.withColumn(
    "rolling_mode_station",
    F.collect_list("station").over(rolling_w)
).withColumn(
    "rolling_mode_station",
    F.udf(lambda x: max(set(x), key=x.count))(F.col("rolling_mode_station"))
).show()

#+------+----------+---------+--------------------+
#|device|start_time|  station|rolling_mode_station|
#+------+----------+---------+--------------------+
#|Python|         1|station_1|           station_1|
#|Python|         2|station_2|           station_1|
#|Python|         3|station_1|           station_1|
#|Python|         4|station_2|           station_2|
#|Python|         5|station_2|           station_2|
#|Python|         6|     null|           station_2|
#+------+----------+---------+--------------------+
Run Code Online (Sandbox Code Playgroud)