How do we set a flag on the last occurrence of a value in a PySpark DataFrame column?

nim*_*ari 5 python sql window-functions pyspark

Requirement: set the flag to 1 on the row where loyal = 1 occurs for the last time (per consumer_id), and 0 everywhere else.

Input:

+-----------+----------+----------+-------+-----+---------+-------+---+
|consumer_id|product_id|    TRX_ID|pattern|loyal| trx_date|row_num| mx|
+-----------+----------+----------+-------+-----+---------+-------+---+
|         11|         1|1152397078|  VVVVM|    1| 3/5/2020|      1|  5|
|         11|         1|1152944770|  VVVVV|    1| 3/6/2020|      2|  5|
|         11|         1|1153856408|  VVVVV|    1|3/15/2020|      3|  5|
|         11|         2|1155884040|  MVVVV|    1| 4/2/2020|      4|  5|
|         11|         2|1156854301|  MMVVV|    0|4/17/2020|      5|  5|
|         12|         1|1156854302|  VVVVM|    1| 3/6/2020|      1|  3|
|         12|         1|1156854303|  VVVVV|    1| 3/7/2020|      2|  3|
|         12|         2|1156854304|  MVVVV|    1|3/16/2020|      3|  3|
+-----------+----------+----------+-------+-----+---------+-------+---+

df = spark.createDataFrame(
    [('11', '1', '1152397078', 'VVVVM', 1, '3/5/2020', 1, 5),
     ('11', '1', '1152944770', 'VVVVV', 1, '3/6/2020', 2, 5),
     ('11', '1', '1153856408', 'VVVVV', 1, '3/15/2020', 3, 5),
     ('11', '2', '1155884040', 'MVVVV', 1, '4/2/2020', 4, 5),
     ('11', '2', '1156854301', 'MMVVV', 0, '4/17/2020', 5, 5),
     ('12', '1', '1156854302', 'VVVVM', 1, '3/6/2020', 1, 3),
     ('12', '1', '1156854303', 'VVVVV', 1, '3/7/2020', 2, 3),
     ('12', '2', '1156854304', 'MVVVV', 1, '3/16/2020', 3, 3)],
    ["consumer_id", "product_id", "TRX_ID", "pattern",
     "loyal", "trx_date", "row_num", "mx"])
df.show()

Expected output:

Note: the Flag column only marks the row containing the last loyal = 1; every other row gets 0.

+-----------+----------+----------+-------+-----+---------+-------+---+----+
|consumer_id|product_id|    TRX_ID|pattern|loyal| trx_date|row_num| mx|Flag|
+-----------+----------+----------+-------+-----+---------+-------+---+----+
|         11|         1|1152397078|  VVVVM|    1| 3/5/2020|      1|  5|   0|
|         11|         1|1152944770|  VVVVV|    1| 3/6/2020|      2|  5|   0|
|         11|         1|1153856408|  VVVVV|    1|3/15/2020|      3|  5|   0|
|         11|         2|1155884040|  MVVVV|    1| 4/2/2020|      4|  5|   1|
|         11|         2|1156854301|  MMVVV|    0|4/17/2020|      5|  5|   0|
|         12|         1|1156854302|  VVVVM|    1| 3/6/2020|      1|  3|   0|
|         12|         1|1156854303|  VVVVV|    1| 3/7/2020|      2|  3|   0|
|         12|         2|1156854304|  MVVVV|    1|3/16/2020|      3|  3|   1|
+-----------+----------+----------+-------+-----+---------+-------+---+----+

What I tried:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w2 = Window().partitionBy("consumer_id").orderBy('row_num')
df = spark.sql("""select * from inter_table""")
# This does not produce the expected Flag column
df = df.withColumn("Flag", F.when(F.last(F.col('loyal') == 1).over(w2), 1).otherwise(0))

There are two cases to handle:

1. The last loyal = 1 is followed by a 0 (see row_num 4 for consumer_id 11).

2. The last loyal = 1 has no row after it (see row_num 3 for consumer_id 12).