在pyspark中的每个DataFrame组中检索前n个

KAs*_*KAs 36 python dataframe apache-spark apache-spark-sql pyspark

pyspark中有一个DataFrame,数据如下:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6
Run Code Online (Sandbox Code Playgroud)

我期望在每个组中返回具有相同user_id的2条记录,这些记录需要具有最高分.因此,结果应如下所示:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5
Run Code Online (Sandbox Code Playgroud)

我是pyspark的新手,有人能给我一个代码片段或门户网站来解决这个问题的相关文档吗?十分感谢!

mto*_*oto 54

我相信你需要使用窗口函数来获得基于user_id和的每一行的等级score,然后过滤你的结果只保留前两个值.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) 
  .filter(col('rank') <= 2) 
  .show() 
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+
Run Code Online (Sandbox Code Playgroud)

一般来说,官方编程指南是开始学习Spark的好地方.

数据

rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
Run Code Online (Sandbox Code Playgroud)

  • 您可以在过滤器中使用窗口函数:`df.filter(rank().over(window)<= 2)` (2认同)
  • 我大吃一惊...我确信我之前在过滤器中使用过窗口函数。但我确实无法重现它(无论是在 2 还是 1.6 中)。我确实以一种奇异的方式使用了它,但我不记得何时或如何使用它。对不起。 (2认同)
  • 您可能需要考虑使用 `row_number` 而不是 `rank`,以防获得相同的排名并且您仍然想要前 n (2认同)

Mar*_*app 19

如果使用Top-n row_number而不是rank在获得等级相等时更准确:

val n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()
Run Code Online (Sandbox Code Playgroud)

注意limit(20).toPandas()技巧而不是show()Jupyter笔记本更好的格式化.

  • 请记住添加“from pyspark.sql.functions import row_number”才能正常工作 (2认同)