追加到 PySpark 数组列

Los*_*ssa 7 arrays append apache-spark apache-spark-sql pyspark

我想检查列值是否在某些边界内。如果不是,我将向数组列“F”附加一些值。这是我到目前为止的代码:

df = spark.createDataFrame(
    [
        (1, 56), 
        (2, 32),
        (3, 99)
    ],
    ['id', 'some_nr'] 
)

df = df.withColumn( "F", F.lit( None ).cast( types.ArrayType( types.ShortType( ) ) ) )

def boundary_check( val ):
  if (val > 60) | (val < 50):
    return 1

udf  = F.udf( lambda x: boundary_check( x ) ) 

df =  df.withColumn("F", udf(F.col("some_nr")))
display(df)
Run Code Online (Sandbox Code Playgroud)

但是,我不知道如何附加到数组。目前,如果我对 df 执行另一次边界检查,它将简单地覆盖“F”中以前的值...

Nap*_*rty 7

看看这里array_union的函数pyspark.sql.functionshttps://spark.apache.org/docs/latest/api/python/pyspark.sql.html ?highlight=join#pyspark.sql.functions.array_union

这样您就可以避免使用udf,从而消除 Spark 并行化的任何好处。代码看起来像这样:

from pyspark.context import SparkContext
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from pyspark.sql import Row
import pyspark.sql.functions as f


conf = SparkConf()
sc = SparkContext(conf=conf)
spark = SparkSession(sc)

df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="a", c3=10),
                            Row(c1=["b", "a", "c"], c2="d", c3=20)])
df.show()
+---------+---+---+
|       c1| c2| c3|
+---------+---+---+
|[b, a, c]|  a| 10|
|[b, a, c]|  d| 20|
+---------+---+---+

df.withColumn(
    "output_column", 
    f.when(f.col("c3") > 10, 
           f.array_union(df.c1, f.array(f.lit("1"))))
     .otherwise(f.col("c1"))
).show()
+---------+---+---+-------------+
|       c1| c2| c3|output_column|
+---------+---+---+-------------+
|[b, a, c]|  a| 10|    [b, a, c]|
|[b, a, c]|  d| 20| [b, a, c, 1]|
+---------+---+---+-------------+
Run Code Online (Sandbox Code Playgroud)

作为旁注,这作为逻辑联合工作,因此如果您想附加一个值,您需要确保该值是唯一的,以便它始终被添加。否则,请查看array functions此处的其他内容: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html ?highlight=join#pyspark.sql.functions.array

注意:您的 Spark 需要是>2.4大多数数组函数的版本。

编辑(根据评论中的要求):

withColumn方法仅允许您一次处理一列,因此您需要使用一个新的withColumn,最好为两个withColumn查询预先定义逻辑语句。

logical_gate = (f.col("c3") > 10)

(
    df.withColumn(
        "output_column", 
        f.when(logical_gate, 
               f.array_union(df.c1, f.array(f.lit("1"))))
         .otherwise(f.col("c1")))
      .withColumn(
        "c3",
        f.when(logical_gate,
               f.lit(None))
         .otherwise(f.col("c3")))
      .show()
)
+---------+---+----+-------------+
|       c1| c2|  c3|output_column|
+---------+---+----+-------------+
|[b, a, c]|  a|  10|    [b, a, c]|
|[b, a, c]|  d|null| [b, a, c, 1]|
+---------+---+----+-------------+
Run Code Online (Sandbox Code Playgroud)