PySpark:检索数据框内组的平均值和均值周围的值

Mat*_*ias 6 python sql window-functions apache-spark apache-spark-sql

我的原始数据以表格格式显示.它包含来自不同变量的观察.每次观察时都有变量名,时间戳和当时的值.

变量[string],Time [datetime],Value [float]

数据作为Parquet存储在HDFS中并加载到Spark Dataframe(df)中.从该数据帧.

现在我想为每个变量计算默认统计数据,如均值,标准差等.之后,一旦检索到Mean,我想过滤/计算那些紧邻Mean的变量值.

因此,我需要先得到每个变量的均值.这就是我使用GroupBy获取每个变量(不是整个数据集)的统计信息的原因.

df_stats = df.groupBy(df.Variable).agg( \
    count(df.Variable).alias("count"), \
    mean(df.Value).alias("mean"), \
    stddev(df.Value).alias("std_deviation"))
Run Code Online (Sandbox Code Playgroud)

通过每个变量的均值,我可以过滤那些围绕均值的特定变量的值(只是计数).因此,我需要该变量的所有观察值(值).这些值位于原始数据帧df中,而不是聚合/分组数据帧df_stats中.

创建统计数据

最后,我想要一个像聚合/分组df_stats这样的数据帧,并使用新列"count_around_mean".

我在考虑使用df_stats.map(...)或df_stats.join(df,df.Variable).但我被困在红色箭头上:(

问题:你怎么会意识到这一点?

临时解决方案:同时我正在使用基于您的想法的解决方案.但是stddev范围2和3的范围函数不起作用.它总是产生一个

AttributeError表示NullType没有_jvm

from pyspark.sql.window import Window
from pyspark.sql.functions import *
from pyspark.sql.types import *

w1 = Window().partitionBy("Variable")
w2 = Window.partitionBy("Variable").orderBy("Time")

def stddev_pop_w(col, w):
    #Built-in stddev doesn't support windowing
    return sqrt(avg(col * col).over(w) - pow(avg(col).over(w), 2))

def isInRange(value, mean, stddev, radius):
    try:
        if (abs(value - mean) < radius * stddev):
            return 1
        else:
            return 0
    except AttributeError:
        return -1

delta = col("Time").cast("long") - lag("Time", 1).over(w2).cast("long")
#f = udf(lambda (value, mean, stddev, radius): abs(value - mean) < radius * stddev, IntegerType())
f2 = udf(lambda value, mean, stddev: isInRange(value, mean, stddev, 2), IntegerType())
f3 = udf(lambda value, mean, stddev: isInRange(value, mean, stddev, 3), IntegerType())

df \
    .withColumn("mean", mean("Value").over(w1)) \
    .withColumn("std_deviation", stddev_pop_w(col("Value"), w1)) \
    .withColumn("delta", delta)
    .withColumn("stddev_2", f2("Value", "mean", "std_deviation")) \
    .withColumn("stddev_3", f3("Value", "mean", "std_deviation")) \
    .show(5, False)

#df2.withColumn("std_dev_3", stddev_range(col("Value"), w1)) \
Run Code Online (Sandbox Code Playgroud)

zer*_*323 4

火花2.0+

您可以替换stddev_pop_w为内置pyspark.sql.functions.stddev*函数之一。

火花<2.0

一般来说,不需要使用 join 进行聚合。相反,您可以计算统计数据,而无需使用窗口函数折叠行。假设您的数据如下所示:

import numpy as np
import pandas as pd
from pyspark.sql.functions import mean

n = 10000
k = 20

np.random.seed(100)

df = sqlContext.createDataFrame(pd.DataFrame({
    "id": np.arange(n),
    "variable": np.random.choice(k, n),
    "value": np.random.normal(0,  1, n)
}))
Run Code Online (Sandbox Code Playgroud)

您可以通过以下方式定义窗口分区variable

from pyspark.sql.window import Window

w = Window().partitionBy("variable")
Run Code Online (Sandbox Code Playgroud)

并计算统计数据如下:

from pyspark.sql.functions import avg, pow, sqrt

def stddev_pop_w(col, w):
    """Builtin stddev doesn't support windowing
    You can easily implement sample variant as well
    """
    return sqrt(avg(col * col).over(w) - pow(avg(col).over(w), 2))


(df
    .withColumn("stddev", stddev_pop_w(col("value"), w))
    .withColumn("mean", avg("value").over(w))
    .show(5, False))

## +---+--------------------+--------+------------------+--------------------+
## |id |value               |variable|stddev            |mean                |
## +---+--------------------+--------+------------------+--------------------+
## |47 |0.77212446947439    |0       |1.0103781346123295|0.035316745261099715|
## |60 |-0.931463439483327  |0       |1.0103781346123295|0.035316745261099715|
## |86 |1.0199074337552294  |0       |1.0103781346123295|0.035316745261099715|
## |121|-1.619408643898953  |0       |1.0103781346123295|0.035316745261099715|
## |145|-0.16065930935765935|0       |1.0103781346123295|0.035316745261099715|
## +---+--------------------+--------+------------------+--------------------+
## only showing top 5 rows
Run Code Online (Sandbox Code Playgroud)

只是为了比较聚合与连接:

from pyspark.sql.functions import stddev, avg, broadcast

df.join(
    broadcast(df.groupBy("variable").agg(avg("value"), stddev("value"))),
    ["variable"]
)
Run Code Online (Sandbox Code Playgroud)