Check whether all elements of an array are present in another array in PySpark

pri*_*iya 1 apache-spark apache-spark-sql pyspark

I have a Spark DataFrame df1:

id     transactions
1      [1, 2, 3, 5]
2      [1, 2, 3, 6]
3      [1, 2, 9, 8]
4      [1, 2, 5, 6]

root
 |-- id: int (nullable = true)
 |-- transactions: array (nullable = false)
     |-- element: int (containsNull = true)
 None

I have another Spark DataFrame df2:

items   cost
  [1]    1.0
  [2]    1.0
 [2, 1]  2.0
 [6, 1]  2.0

root
 |-- items: array (nullable = false)
    |-- element: int (containsNull = true)
 |-- cost: int (nullable = true)
 None

I want to check whether all elements of each array in the items column are present in the transactions column.

The first row ([1, 2, 3, 5]) contains [1], [2], and [2, 1] from the items column, so I need to sum their corresponding costs: 1.0 + 1.0 + 2.0 = 4.0.

The output I want is:

id     transactions    score
1      [1, 2, 3, 5]   4.0
2      [1, 2, 3, 6]   6.0
3      [1, 2, 9, 8]   4.0
4      [1, 2, 5, 6]   6.0
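For reference, the scoring rule can be written as a small plain-Python sketch over the sample data above. No Spark is involved, and the names `transactions`, `items_costs` and `score` are illustrative only. Each items array is treated as a set, and its cost is added whenever that set is a subset of the row's transactions:

```python
# Sample data copied from df1 and df2 above
transactions = {
    1: [1, 2, 3, 5],
    2: [1, 2, 3, 6],
    3: [1, 2, 9, 8],
    4: [1, 2, 5, 6],
}
items_costs = [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)]


def score(row):
    """Sum the cost of every items array whose elements all appear in `row`."""
    present = set(row)
    return sum(cost for items, cost in items_costs if set(items) <= present)


scores = {i: score(row) for i, row in transactions.items()}
print(scores)  # {1: 4.0, 2: 6.0, 3: 4.0, 4: 6.0}
```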

I tried using a loop with collect()/toLocalIterator, but it doesn't seem efficient, and I will have large data.

I thought creating a UDF like this would solve it, but it raises an error:

from pyspark.sql.functions import udf
def containsAll(x,y):
  result =  all(elem in x  for elem in y)

  if result:
    print("Yes, transactions contains all items")    
  else :
    print("No")

contains_udf = udf(containsAll)
dataFrame.withColumn("result",
    contains_udf(df2.items, df1.transactions)).show()

Or is there another way?

use*_*362 5

A valid udf for versions before 2.4 (note that it has to return something):

from pyspark.sql.functions import udf

@udf("boolean")
def contains_all(x, y):
    if x is not None and y is not None:
        return set(y).issubset(set(x))
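The problem with the question's UDF can be seen without Spark. `containsAll` only prints, so it implicitly returns `None`, which a UDF turns into null. Also, the question calls it as `contains_udf(df2.items, df1.transactions)`, which with that signature checks the transactions against the items, the opposite direction. Below is a minimal plain-Python comparison. `contains_all_py` is a hypothetical undecorated copy of the UDF body above:

```python
def containsAll(x, y):
    # The question's version: prints a message but never returns (so None)
    result = all(elem in x for elem in y)
    if result:
        print("Yes, transactions contains all items")
    else:
        print("No")


def contains_all_py(x, y):
    # Same body as the @udf("boolean") function above, without the decorator
    if x is not None and y is not None:
        return set(y).issubset(set(x))
    # Falls through to None (null in Spark) if either array is null


print(containsAll([1, 2, 3, 5], [2, 1]))      # prints the message, then None
print(contains_all_py([1, 2, 3, 5], [2, 1]))  # True
print(contains_all_py([1, 2, 3, 5], [6, 1]))  # False
print(contains_all_py(None, [1]))             # None
```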

In 2.4 or later, no udf is needed:

from pyspark.sql.functions import array_intersect, size

def contains_all(x, y):
    return size(array_intersect(x, y)) == size(y)
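One caveat of the size comparison, assuming Spark's documented behaviour that `array_intersect` returns the intersection without duplicates: if an `items` array itself repeats a value, the intersection is shorter than `items`. The check then returns false even though every element is present. In 2.4 and later, `size(array_except(y, x)) == 0` should avoid this. The plain-Python helpers below are hypothetical emulations, not Spark functions:

```python
def array_intersect_emul(x, y):
    # Elements of x that also occur in y, each kept once (no duplicates)
    seen, out = set(), []
    ys = set(y)
    for e in x:
        if e in ys and e not in seen:
            seen.add(e)
            out.append(e)
    return out


def contains_all_by_size(x, y):
    # Mirrors size(array_intersect(x, y)) == size(y)
    return len(array_intersect_emul(x, y)) == len(y)


def contains_all_by_subset(x, y):
    # Mirrors set(y).issubset(set(x)) from the pre-2.4 UDF
    return set(y) <= set(x)


print(contains_all_by_size([1, 2, 3, 5], [2, 1]))    # True
print(contains_all_by_size([1, 2, 3, 5], [1, 1]))    # False: intersection is [1]
print(contains_all_by_subset([1, 2, 3, 5], [1, 1]))  # True
```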

Usage:

from pyspark.sql.functions import col, sum as sum_, when

df1 = spark.createDataFrame(
   [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])],
   ("id", "transactions")
)

df2 = spark.createDataFrame(
    [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)],
    ("items", "cost")
)


(df1
    .crossJoin(df2).groupBy("id", "transactions")
    .agg(sum_(when(
        contains_all("transactions", "items"), col("cost")
    )).alias("score"))
    .show())

Result:

+---+------------+-----+                                                        
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
+---+------------+-----+
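The `crossJoin` / `groupBy` / `sum(when(...))` pipeline can be mirrored in plain Python as a sketch on the sample data, not Spark code. Every transaction row is paired with every items row. `when` keeps `cost` only for matching pairs and yields null (`None` here) otherwise, and `sum` skips nulls, returning null if nothing matched. For brevity the sketch groups by `id` alone:

```python
from collections import defaultdict
from itertools import product

df1_rows = [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])]
df2_rows = [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)]

# crossJoin: every (id, transactions) row paired with every (items, cost) row
pairs = list(product(df1_rows, df2_rows))

# when(contains_all(...), cost): cost for matching pairs, None otherwise
grouped = defaultdict(list)
for (id_, transactions), (items, cost) in pairs:
    grouped[id_].append(cost if set(items) <= set(transactions) else None)


def spark_like_sum(values):
    # SQL sum ignores nulls and returns null if every input is null
    non_null = [v for v in values if v is not None]
    return sum(non_null) if non_null else None


score = {id_: spark_like_sum(values) for id_, values in grouped.items()}
print(len(pairs))  # 16 rows after the cross join (4 x 4)
print(score)       # {1: 4.0, 2: 6.0, 3: 4.0, 4: 6.0}
```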

If df2 is small, it may be preferable to collect it and use it as a local (broadcast) variable:

items = sc.broadcast([
    (set(items), cost) for items, cost in df2.select("items", "cost").collect()
])

def score(y):
    @udf("double")
    def _(x):
        if x is not None:
            transactions = set(x)
            return sum(
                cost for items, cost in y.value 
                if items.issubset(transactions))
    return _


df1.withColumn("score", score(items)("transactions")).show()
+---+------------+-----+                                                        
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
+---+------------+-----+
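The factory pattern used above can be exercised without a cluster: a function takes the broadcast and returns a UDF that closes over it. In this sketch `LocalBroadcast` is a hypothetical stand-in that only mimics the `.value` attribute of a Spark broadcast, and the `@udf` decorator is dropped. There is one deliberate deviation. `sum` gets a `0.0` start value, so a row matching no items yields the float `0.0` instead of the integer `0`, which seems safer for a UDF declared as `double`:

```python
class LocalBroadcast:
    """Hypothetical stand-in for a Spark Broadcast: only exposes .value."""

    def __init__(self, value):
        self.value = value


def score(y):
    # Same shape as the factory above, minus @udf: returns a function of one
    # transactions array that reads the shared (items_set, cost) list via y.value
    def _(x):
        if x is not None:
            transactions = set(x)
            return sum(
                (cost for items, cost in y.value if items.issubset(transactions)),
                0.0,
            )
    return _


# Built once on the "driver", like the collect() + sc.broadcast(...) step
items = LocalBroadcast(
    [(set(i), c) for i, c in [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)]]
)
score_fn = score(items)
results = [score_fn(t) for t in [[1, 2, 3, 5], [1, 2, 3, 6], [1, 2, 9, 8], [1, 2, 5, 6]]]
print(results)  # [4.0, 6.0, 4.0, 6.0]
```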

Finally, you can explode and join:

+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
+---+------------+-----+
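The explode-and-join idea can be sketched in plain Python on the sample data. This is a dataflow illustration, not Spark code. The surrogate key per df2 row and the `size` bookkeeping are assumptions about how a full match is detected. In Spark the corresponding steps would be `explode`, a `join` on the item value, and a `groupBy(...).count()`. The sketch also assumes no array contains duplicate values, because duplicates would inflate the match counts:

```python
from collections import Counter, defaultdict

df1_rows = [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])]
df2_rows = [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)]

# explode("transactions"): one (id, item) row per element
exploded_tx = [(id_, item) for id_, tx in df1_rows for item in tx]

# Tag each df2 row with a surrogate key, keep its size, then explode("items")
exploded_items = [
    (key, item, len(items), cost)
    for key, (items, cost) in enumerate(df2_rows)
    for item in items
]

# Join on the item value
joined = [
    (id_, key, size, cost)
    for id_, t_item in exploded_tx
    for key, i_item, size, cost in exploded_items
    if t_item == i_item
]

# Count matched elements per (id, key); a full match means count == size
matches = Counter(joined)
costs = defaultdict(float)
for (id_, key, size, cost), count in matches.items():
    if count == size:
        costs[id_] += cost

print(len(joined))  # 22 joined rows for this sample
print(dict(costs))  # {1: 4.0, 2: 6.0, 3: 4.0, 4: 6.0}
```

On this tiny sample the join produces more rows (22) than the 4 x 4 = 16 rows of the cross join. This shows how the trade-off depends on the data. Ids with no fully matched items would not appear in `costs` at all.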

Then join the result back to the original df1:

df1.join(costs, ["id"])

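However, this is not a straightforward solution and requires multiple shuffles. It may still be preferable to the Cartesian product (crossJoin), but that depends on the actual data.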