从PySpark阵列列中删除重复项

Tho*_*mas 1 python apache-spark apache-spark-sql pyspark

我有一个PySpark数据框,其中包含一ArrayType(StringType())列。该列包含数组中需要删除的重复字符串。例如,一行条目可能看起来像[milk, bread, milk, toast]。假设我的数据框已命名df,我的列已命名arraycol。我需要类似的东西:

df = df.withColumn("arraycol_without_dupes", F.remove_dupes_from_array("arraycol"))
Run Code Online (Sandbox Code Playgroud)

我的直觉是对此有一个简单的解决方案,但是在浏览stackoverflow 15分钟后,我发现没有比分解该列,删除整个数据帧上的重复项然后再进行分组更好的了。目前已经得到了成为一个更简单的方法,我只是没想到吧?

我正在使用Spark版本'2.3.1'。

pau*_*ult 5

对于pyspark 2.4+版本,您可以使用pyspark.sql.functions.array_distinct

from pyspark.sql.functions import array_distinct
df = df.withColumn("arraycol_without_dupes", array_distinct("arraycol"))
Run Code Online (Sandbox Code Playgroud)

对于较旧的版本,您可以使用explode+ groupBy和使用API​​函数来执行此操作collect_set,但是udf在这里a 可能更有效:

from pyspark.sql.functions import udf

remove_dupes_from_array = udf(lambda row: list(set(row)), ArrayType(StringType()))
df = df.withColumn("arraycol_without_dupes", remove_dupes_from_array("arraycol"))
Run Code Online (Sandbox Code Playgroud)