Sum of a variable number of columns in PySpark

Ser*_*gei 2 python apache-spark apache-spark-sql pyspark

I have a Spark DataFrame like this:

+-----+--------+-------+-------+-------+-------+-------+
| Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|
+-----+--------+-------+-------+-------+-------+-------+
|  Cat|       1|      1|      2|      3|      4|      5|
|  Dog|       2|      1|      2|      3|      4|      5|
|Mouse|       4|      1|      2|      3|      4|      5|
|  Fox|       5|      1|      2|      3|      4|      5|
+-----+--------+-------+-------+-------+-------+-------+

You can reproduce it with the following code:

data = [('Cat', 1, 1, 2, 3, 4, 5),
        ('Dog', 2, 1, 2, 3, 4, 5),
        ('Mouse', 4, 1, 2, 3, 4, 5),
        ('Fox', 5, 1, 2, 3, 4, 5)]
columns = ['Type', 'Criteria', 'Value#1', 'Value#2', 'Value#3', 'Value#4', 'Value#5']
df = spark.createDataFrame(data, schema=columns)
df.show()

My task is to add a Total column that is the sum of all Value columns whose # number is no greater than the row's Criteria.

In this example:

  • For row 'Cat': Criteria is 1, so Total is just Value#1.
  • For row 'Dog': Criteria is 2, so Total is the sum of Value#1 and Value#2.
  • For row 'Fox': Criteria is 5, so Total is the sum of all value columns (Value#1 through Value#5).

The result should look like this:

+-----+--------+-------+-------+-------+-------+-------+-----+
| Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|Total|
+-----+--------+-------+-------+-------+-------+-------+-----+
|  Cat|       1|      1|      2|      3|      4|      5|    1|
|  Dog|       2|      1|      2|      3|      4|      5|    3|
|Mouse|       4|      1|      2|      3|      4|      5|   10|
|  Fox|       5|      1|      2|      3|      4|      5|   15|
+-----+--------+-------+-------+-------+-------+-------+-----+

I can do this with a Python UDF, but my dataset is large, and Python UDFs are slow because of serialization. I'm looking for a pure Spark solution.
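For reference, a minimal sketch of the UDF approach I'd like to avoid (the helper names _total and total_udf are mine):

from pyspark.sql.functions import udf, col
from pyspark.sql.types import LongType

def _total(criteria, *values):
    # Runs per row in a Python worker: each row is serialized out of
    # the JVM, summed here, and the result serialized back.
    return sum(values[:criteria])

total_udf = udf(_total, LongType())

value_cols = [col('Value#{}'.format(i)) for i in range(1, 6)]
df.withColumn('Total', total_udf(col('Criteria'), *value_cols)).show()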

I'm using PySpark with Spark 2.1.

hi-*_*zir 5

You can easily adapt to PySpark the solution from "calculate row maximum of a subset of columns and add to existing dataframe" by user6910411:

from pyspark.sql.functions import col, when

# Include Value#i only when the row's Criteria is at least i,
# contributing 0 otherwise; the builtin sum folds these conditional
# columns into a single Spark expression.
total = sum([
    when(col("Criteria") >= i, col("Value#{}".format(i))).otherwise(0)
    for i in range(1, 6)
])

df.withColumn("total", total).show()

# +-----+--------+-------+-------+-------+-------+-------+-----+
# | Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|total|
# +-----+--------+-------+-------+-------+-------+-------+-----+
# |  Cat|       1|      1|      2|      3|      4|      5|    1|
# |  Dog|       2|      1|      2|      3|      4|      5|    3|
# |Mouse|       4|      1|      2|      3|      4|      5|   10|
# |  Fox|       5|      1|      2|      3|      4|      5|   15|
# +-----+--------+-------+-------+-------+-------+-------+-----+
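Note that sum does no per-row work here: + on Column objects builds a Spark expression, so the comprehension merely assembles one JVM-side column. Spelled out for the five-column case, it is equivalent to:

from pyspark.sql.functions import col, when

# What the builtin sum assembles (it also prepends a literal 0,
# which is a no-op for the addition):
total = (
    when(col("Criteria") >= 1, col("Value#1")).otherwise(0)
    + when(col("Criteria") >= 2, col("Value#2")).otherwise(0)
    + when(col("Criteria") >= 3, col("Value#3")).otherwise(0)
    + when(col("Criteria") >= 4, col("Value#4")).otherwise(0)
    + when(col("Criteria") >= 5, col("Value#5")).otherwise(0)
)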

For an arbitrary set of ordered value columns, define a list:

cols = df.columns[2:]

and redefine the total as:

# cols[i] is the (i+1)-th value column, so the condition
# Criteria > i selects exactly the first Criteria columns.
total_ = sum([
    when(col("Criteria") > i, col(cols[i])).otherwise(0)
    for i in range(len(cols))
])

df.withColumn("total", total_).show()
# +-----+--------+-------+-------+-------+-------+-------+-----+
# | Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|total|
# +-----+--------+-------+-------+-------+-------+-------+-----+
# |  Cat|       1|      1|      2|      3|      4|      5|    1|
# |  Dog|       2|      1|      2|      3|      4|      5|    3|
# |Mouse|       4|      1|      2|      3|      4|      5|   10|
# |  Fox|       5|      1|      2|      3|      4|      5|   15|
# +-----+--------+-------+-------+-------+-------+-------+-----+
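The sample data contains no nulls; if the value columns could hold them, a null in a selected column would make the total null for that row, since + propagates nulls. Assuming nulls should count as 0 (an assumption, not part of the question), a null-safe sketch:

from pyspark.sql.functions import coalesce, col, lit, when

cols = df.columns[2:]

# Replace a null value with 0 before it enters the sum.
total_null_safe = sum([
    when(col("Criteria") > i, coalesce(col(cols[i]), lit(0))).otherwise(0)
    for i in range(len(cols))
])

df.withColumn("total", total_null_safe).show()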

Important:

The sum here is Python's builtin sum (__builtin__.sum), not pyspark.sql.functions.sum.
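If you prefer not to depend on which sum is in scope (for example after a from pyspark.sql.functions import *), an equivalent fold with functools.reduce sidesteps the name entirely:

from functools import reduce
from operator import add
from pyspark.sql.functions import col, when

cols = df.columns[2:]

# Fold the conditional columns with + explicitly instead of
# relying on the builtin sum.
total_ = reduce(add, [
    when(col("Criteria") > i, col(cols[i])).otherwise(0)
    for i in range(len(cols))
])

df.withColumn("total", total_).show()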