将组计数列添加到PySpark数据帧

Dav*_*ein 9 dplyr apache-spark pyspark

由于其优越的Spark处理能力,我来自R和tidyverse到PySpark,我正在努力将某些概念从一个上下文映射到另一个上下文.

特别是,假设我有一个如下所示的数据集

x | y
--+--
a | 5
a | 8
a | 7
b | 1
Run Code Online (Sandbox Code Playgroud)

我想添加一个包含每个x值的行数的列,如下所示:

x | y | n
--+---+---
a | 5 | 3
a | 8 | 3
a | 7 | 3
b | 1 | 1
Run Code Online (Sandbox Code Playgroud)

在dplyr中,我只想说:

import(tidyverse)

df <- read_csv("...")
df %>%
    group_by(x) %>%
    mutate(n = n()) %>%
    ungroup()
Run Code Online (Sandbox Code Playgroud)

就是这样.如果我想按行数总结一下,我可以在PySpark中做一些简单的事情:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

spark.read.csv("...") \
    .groupBy(col("x")) \
    .count() \
    .show()
Run Code Online (Sandbox Code Playgroud)

而且我认为我明白这withColumn相当于dplyr的mutate.但是,当我执行以下操作时,PySpark告诉我withColumn没有为groupBy数据定义:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()

spark.read.csv("...") \
    .groupBy(col("x")) \
    .withColumn("n", count("x")) \
    .show()
Run Code Online (Sandbox Code Playgroud)

在短期内,我可以简单地创建包含计数的第二个数据帧并将其连接到原始数据帧.但是,在大型表格的情况下,这似乎会变得效率低下.实现这一目标的规范方法是什么?

pau*_*ult 17

执行操作时groupBy(),必须先指定聚合,然后才能显示结果.例如:

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
]
df = sqlCtx.createDataFrame(data, ["x", "y"])
df.groupBy('x').count().select('x', f.col('count').alias('n')).show()
#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+
Run Code Online (Sandbox Code Playgroud)

在这里,我曾经alias()重命名该列.但这只会返回每组一行.如果您想要附加计数的所有行,您可以使用以下命令执行此操作Window:

from pyspark.sql import Window
w = Window.partitionBy('x')
df.select('x', 'y', f.count('x').over(w).alias('n')).sort('x', 'y').show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+
Run Code Online (Sandbox Code Playgroud)

或者,如果您对SQL更熟悉,可以将数据帧注册为临时表,并利用它pyspark-sql来执行相同的操作:

df.registerTempTable('table')
sqlCtx.sql(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM table ORDER BY x, y'
).show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+
Run Code Online (Sandbox Code Playgroud)


wie*_*u_p 10

作为@pault 附录

import pyspark.sql.functions as F

...

(df
.groupBy(F.col('x'))
.agg(F.count('x').alias('n'))
.show())

#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+
Run Code Online (Sandbox Code Playgroud)

享受


小智 5

我发现我们可以更接近 tidyverse 示例:

from pyspark.sql import Window
w = Window.partitionBy('x')
df.withColumn('n', f.count('x').over(w)).sort('x', 'y').show()
Run Code Online (Sandbox Code Playgroud)