Aar*_*sek 7 apache-spark apache-spark-sql pyspark
在pyspark中,我有一个可变长度的double数组,我希望找到其均值。但是,平均值函数需要单个数字类型。
有没有一种方法可以找到一个数组的平均值而不分解该数组?我有几个不同的数组,我希望能够执行以下操作:
df.select(col("Segment.Points.trajectory_points.longitude"))
Run Code Online (Sandbox Code Playgroud)
DataFrame [经度:数组]
df.select(avg(col("Segment.Points.trajectory_points.longitude"))).show()
Run Code Online (Sandbox Code Playgroud)
Run Code Online (Sandbox Code Playgroud)org.apache.spark.sql.AnalysisException: cannot resolve 'avg(Segment.Points.trajectory_points.longitude)' due to data type mismatch: function average requires numeric types, not ArrayType(DoubleType,true);;
如果我有3个具有以下数组的唯一记录,我希望将这些值的平均值作为输出。这将是3个平均经度值。
输入:
[Row(longitude=[-80.9, -82.9]),
Row(longitude=[-82.92, -82.93, -82.94, -82.96, -82.92, -82.92]),
Row(longitude=[-82.93, -82.93])]
Run Code Online (Sandbox Code Playgroud)
输出:
-81.9,
-82.931,
-82.93
Run Code Online (Sandbox Code Playgroud)
我正在使用Spark版本2.1.3。
爆炸解决方案:
因此,我已经通过爆炸实现了这一目标,但我希望避免这一步。这就是我所做的
from pyspark.sql.functions import col
import pyspark.sql.functions as F
longitude_exp = df.select(
col("ID"),
F.posexplode("Segment.Points.trajectory_points.longitude").alias("pos", "longitude")
)
longitude_reduced = long_exp.groupBy("ID").agg(avg("longitude"))
Run Code Online (Sandbox Code Playgroud)
这成功地取了意思。但是,由于我将在几列中执行此操作,因此必须将同一DF爆炸几次。我将继续努力,以找到一种更清洁的方式来完成此任务。
在您的情况下,您的选择是 useexplode或udf. 正如您所指出的,这explode是不必要的昂贵。因此, audf是要走的路。
您可以编写自己的函数来获取数字列表的平均值,或者只是捎带numpy.mean. 如果使用numpy.mean,则必须将结果转换为 a float(因为 spark 不知道如何处理numpy.float64s)。
import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
array_mean = udf(lambda x: float(np.mean(x)), FloatType())
df.select(array_mean("longitude").alias("avg")).show()
#+---------+
#| avg|
#+---------+
#| -81.9|
#|-82.93166|
#| -82.93|
#+---------+
Run Code Online (Sandbox Code Playgroud)
在最近的 Spark 版本(2.4 或更高版本)中,最有效的解决方案是使用aggregate高阶函数:
from pyspark.sql.functions import expr
query = """aggregate(
`{col}`,
CAST(0.0 AS double),
(acc, x) -> acc + x,
acc -> acc / size(`{col}`)
) AS `avg_{col}`""".format(col="longitude")
df.selectExpr("*", query).show()
Run Code Online (Sandbox Code Playgroud)
+--------------------+------------------+
| longitude| avg_longitude|
+--------------------+------------------+
| [-80.9, -82.9]| -81.9|
|[-82.92, -82.93, ...|-82.93166666666667|
| [-82.93, -82.93]| -82.93|
+--------------------+------------------+
Run Code Online (Sandbox Code Playgroud)
另请参阅Spark Scala row-wise average by processing null
| 归档时间: |
|
| 查看次数: |
530 次 |
| 最近记录: |