比较 Scala Spark 中的两个数组列

Ven*_*thi 2 scala apache-spark array-column

我有一个下面给出的格式的数据框。

movieId1 | genreList1              | genreList2
--------------------------------------------------
1        |[Adventure,Comedy]       |[Adventure]
2        |[Animation,Drama,War]    |[War,Drama]
3        |[Adventure,Drama]        |[Drama,War]
Run Code Online (Sandbox Code Playgroud)

并尝试创建另一个标志列,显示流派列表 2 是否是流派列表 1 的子集

movieId1 | genreList1              | genreList2        | Flag
---------------------------------------------------------------
1        |[Adventure,Comedy]       | [Adventure]       |1
2        |[Animation,Drama,War]    | [War,Drama]       |1
3        |[Adventure,Drama]        | [Drama,War]       |0
Run Code Online (Sandbox Code Playgroud)

我试过这个

def intersect_check(a: Array[String], b: Array[String]): Int = {
  if (b.sameElements(a.intersect(b))) { return 1 } 
  else { return 2 }
}

def intersect_check_udf =
  udf((colvalue1: Array[String], colvalue2: Array[String]) => intersect_check(colvalue1, colvalue2))

data = data.withColumn("Flag", intersect_check_udf(col("genreList1"), col("genreList2")))
Run Code Online (Sandbox Code Playgroud)

但这会引发org.apache.spark.SparkException: Failed to execute user defined function.错误。关于如何解决这个问题的任何想法。PS:上面的函数 ( intersect_check) 适用于Arrays。

mto*_*oto 5

我们可以定义一个udf计算intersectionArray列之间的长度并检查它是否等于第二列的长度。如果是,则第二个数组是第一个数组的子集。

此外,您udf需要的输入是 class WrappedArray[String],而不是Array[String]

import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.col

val same_elements = udf { (a: WrappedArray[String], 
                           b: WrappedArray[String]) => 
  if (a.intersect(b).length == b.length){ 1 }else{ 0 }  
}

df.withColumn("test",same_elements(col("genreList1"),col("genreList2")))
  .show(truncate = false)
+--------+-----------------------+------------+----+
|movieId1|genreList1             |genreList2  |test|
+--------+-----------------------+------------+----+
|1       |[Adventure, Comedy]    |[Adventure] |1   |
|2       |[Animation, Drama, War]|[War, Drama]|1   |
|3       |[Adventure, Drama]     |[Drama, War]|0   |
+--------+-----------------------+------------+----+
Run Code Online (Sandbox Code Playgroud)

数据

val df = List((1,Array("Adventure","Comedy"), Array("Adventure")),
              (2,Array("Animation","Drama","War"), Array("War","Drama")),
              (3,Array("Adventure","Drama"),Array("Drama","War"))).toDF("movieId1","genreList1","genreList2")
Run Code Online (Sandbox Code Playgroud)