Spark - 带递归的窗口?- 有条件地跨行传播值

Han*_*ans 1 window-functions apache-spark apache-spark-sql pyspark pyspark-sql

我有以下数据框显示购买收入。

+-------+--------+-------+
|user_id|visit_id|revenue|
+-------+--------+-------+
|      1|       1|      0|
|      1|       2|      0|
|      1|       3|      0|
|      1|       4|    100|
|      1|       5|      0|
|      1|       6|      0|
|      1|       7|    200|
|      1|       8|      0|
|      1|       9|     10|
+-------+--------+-------+
Run Code Online (Sandbox Code Playgroud)

最终,我希望新列purch_revenue在每一行中显示购买产生的收入。作为一种解决方法,我还尝试引入一个购买标识符purch_id,每次购买时都会增加该标识符。所以这只是作为参考列出。

+-------+--------+-------+-------------+--------+
|user_id|visit_id|revenue|purch_revenue|purch_id|
+-------+--------+-------+-------------+--------+
|      1|       1|      0|          100|       1|
|      1|       2|      0|          100|       1|
|      1|       3|      0|          100|       1|
|      1|       4|    100|          100|       1|
|      1|       5|      0|          100|       2|
|      1|       6|      0|          100|       2|
|      1|       7|    200|          100|       2|
|      1|       8|      0|          100|       3|
|      1|       9|     10|          100|       3|
+-------+--------+-------+-------------+--------+
Run Code Online (Sandbox Code Playgroud)

我试图使用这样的lag/lead功能:

+-------+--------+-------+
|user_id|visit_id|revenue|
+-------+--------+-------+
|      1|       1|      0|
|      1|       2|      0|
|      1|       3|      0|
|      1|       4|    100|
|      1|       5|      0|
|      1|       6|      0|
|      1|       7|    200|
|      1|       8|      0|
|      1|       9|     10|
+-------+--------+-------+
Run Code Online (Sandbox Code Playgroud)

这将复制收入列 ifrevenue > 0并将其拉高一行。显然,我可以将其链接到有限的 N,但这不是解决方案。

  • 有没有办法递归地应用这个直到revenue > 0
  • 或者,有没有办法根据条件增加值?我试图找出一种方法来做到这一点,但很难找到。

zer*_*323 5

窗口函数不支持递归,但这里不需要。这种类型的分离可以通过累积和轻松处理:

from pyspark.sql.functions import col, sum, when, lag
from pyspark.sql.window import Window

w = Window.partitionBy("user_id").orderBy("visit_id")
purch_id = sum(lag(when(
    col("revenue") > 0, 1).otherwise(0), 
    1, 0
).over(w)).over(w) + 1

df.withColumn("purch_id", purch_id).show()
Run Code Online (Sandbox Code Playgroud)
+-------+--------+-------+--------+
|user_id|visit_id|revenue|purch_id|
+-------+--------+-------+--------+
|      1|       1|      0|       1|
|      1|       2|      0|       1|
|      1|       3|      0|       1|
|      1|       4|    100|       1|
|      1|       5|      0|       2|
|      1|       6|      0|       2|
|      1|       7|    200|       2|
|      1|       8|      0|       3|
|      1|       9|     10|       3|
+-------+--------+-------+--------+
Run Code Online (Sandbox Code Playgroud)