如何在 Spark 中动态切片数组列?

har*_*ppu 3 python apache-spark apache-spark-sql pyspark

Spark 2.4 引入了新的 SQL 函数slice,可用于从数组列中提取一定范围的元素。我想根据一个整数列动态定义每行的范围,该列具有我想从该列中选择的元素数。

但是,简单地将列传递给 slice 函数会失败,该函数似乎需要整数作为起始值和结束值。有没有办法在不编写 UDF 的情况下做到这一点?

用一个例子来形象化这个问题:我有一个带有数组列的数据框,arr在每一行中都有一个看起来像['a', 'b', 'c']. 还有一个end_idx包含元素的列31并且2

+---------+-------+
|arr      |end_idx|
+---------+-------+
|[a, b, c]|3      |
|[a, b, c]|1      |
|[a, b, c]|2      |
+---------+-------+
Run Code Online (Sandbox Code Playgroud)

我尝试创建一个这样的新列arr_trimmed

+---------+-------+
|arr      |end_idx|
+---------+-------+
|[a, b, c]|3      |
|[a, b, c]|1      |
|[a, b, c]|2      |
+---------+-------+
Run Code Online (Sandbox Code Playgroud)

我希望这个代码与内容创建新列['a', 'b', 'c']['a']['a', 'b']

相反,我收到一个错误TypeError: Column is not iterable

Dav*_*rba 9

您可以通过传递 SQL 表达式来实现,如下所示:

df.withColumn("arr_trimmed", F.expr("slice(arr, 1, end_idx)"))

Run Code Online (Sandbox Code Playgroud)

这是整个工作示例:

import pyspark.sql.functions as F

l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]

df = spark.createDataFrame(l, ["arr", "end_idx"])

df.withColumn("arr_trimmed", F.expr("slice(arr, 1, end_idx)")).show(truncate=False)

+---------+-------+-----------+
|arr      |end_idx|arr_trimmed|
+---------+-------+-----------+
|[a, b, c]|3      |[a, b, c]  |
|[a, b, c]|1      |[a]        |
|[a, b, c]|2      |[a, b]     |
+---------+-------+-----------+
Run Code Online (Sandbox Code Playgroud)