Pairwise combinations of an array column in PySpark

Zyg*_*ygD 1 python arrays combinations apache-spark pyspark

Similar to this question (Scala), but I need the combinations (pairwise combinations of an array column) in PySpark.

Example input:

df = spark.createDataFrame(
    [([0, 1],),
     ([2, 3, 4],),
     ([5, 6, 7, 8],)],
    ['array_col'])

Expected output:

+------------+------------------------------------------------+
|array_col   |out                                             |
+------------+------------------------------------------------+
|[0, 1]      |[[0, 1]]                                        |
|[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
|[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
+------------+------------------------------------------------+
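
In plain Python terms, the desired "out" for each row is simply the list of 2-element combinations of the array (a quick illustration, not part of the original question):

from itertools import combinations

list(combinations([2, 3, 4], 2))  # [(2, 3), (2, 4), (3, 4)]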

Zyg*_*ygD 5

Native Spark approach. I have translated that answer into PySpark.

Python 3.8+ (using the walrus := operator, since "array_col" is repeated several times in this script):

from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            F.flatten(F.transform(
                c:="array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size(c)), c)
            )),
            lambda x: F.array(x["0"], x[c])
        ),
        lambda x: x[0] < x[1]
    )
)
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col   |out                                             |
# +------------+------------------------------------------------+
# |[0, 1]      |[[0, 1]]                                        |
# |[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+
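
How the expression works, step by step (a minimal sketch; the lambda-accepting F.transform / F.filter used above are available from Spark 3.1+, which is why a Spark 2.4 alternative is given further down; the column name debug_pairs is only for illustration):

from pyspark.sql import functions as F

# For every element x, pair it with the whole array:
# array_repeat(x, size) gives [x, x, ...], and arrays_zip lines that up
# element-by-element with array_col, producing one struct (x, y) per y.
# Flattening merges the per-element results, so for [0, 1] this column
# holds all ordered pairs: (0, 0), (0, 1), (1, 0), (1, 1).
ordered_pairs = F.flatten(F.transform(
    "array_col",
    lambda x: F.arrays_zip(F.array_repeat(x, F.size("array_col")), "array_col")
))
df.withColumn("debug_pairs", ordered_pairs).show(truncate=0)

# The outer transform then turns each struct into a plain 2-element array,
# and the final filter keeps only pairs with x[0] < x[1], which removes
# self-pairs and reversed duplicates, leaving [[0, 1]] for the first row.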

Alternative without the walrus operator:

from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            F.flatten(F.transform(
                "array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size("array_col")), "array_col")
            )),
            lambda x: F.array(x["0"], x["array_col"])
        ),
        lambda x: x[0] < x[1]
    )
)

Alternative for Spark 2.4+:

from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.expr("""
        filter(
            transform(
                flatten(transform(
                    array_col,
                    x -> arrays_zip(array_repeat(x, size(array_col)), array_col)
                )),
                x -> array(x["0"], x["array_col"])
            ),
            x -> x[0] < x[1]
        )
    """)
)
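
Not part of the original answer, but as a sanity check: each "out" row should equal list(itertools.combinations(arr, 2)). A minimal sketch comparing against a plain Python UDF (the names pair_combinations and out_udf are illustrative only):

from itertools import combinations

from pyspark.sql import functions as F, types as T

# Reference implementation in plain Python, used only for validation.
@F.udf(T.ArrayType(T.ArrayType(T.LongType())))
def pair_combinations(arr):
    return [list(p) for p in combinations(arr, 2)]

df.withColumn("out_udf", pair_combinations("array_col")).show(truncate=0)
# The "out" and "out_udf" columns should match row by row.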