数据框列上的火花 scala 模式匹配

Ros*_*n A 5 regex scala apache-spark

我来自 R 背景。我可以在 R 中的 Dataframe col 上实现模式搜索。但现在很难在 spark scala 中做到这一点。任何帮助,将不胜感激

问题陈述被分解成细节只是为了适当地描述它 DF :

           Case                      Freq
            135322                     265
     183201,135322                      36
     135322,135322                      18
     135322,121200                      11
     121200,135322                       8
     112107,112107                       7
     183201,135322,135322                4
     112107,135322,183201,121200,80000   2
Run Code Online (Sandbox Code Playgroud)

我正在寻找模式搜索 UDF,它会返回模式的所有匹配项,然后返回第二列中的相应 Freq 值。

例如:对于模式135322,我想找出第一个 col Case 中的所有匹配项。它应该从 Freq col 返回相应的 Freq 编号。喜欢265,36,18,11,8,4,2

对于模式,112107,112107它应该仅仅7因为有一个匹配模式而返回。

这就是最终结果的样子

          Case                           Freq   results
            135322                       265    256+36+18+11+8+4+2
     183201,135322                        36    36+4+2
     135322,135322                        18    18+4
     135322,121200                        11    11+2
     121200,135322                         8    8+2
     112107,112107                         7    7
     183201,135322,135322                  4    4
     112107,135322,183201,121200,80000     2    2
Run Code Online (Sandbox Code Playgroud)

到目前为止我尝试过的:

val text= DF.select("case").collect().map(_.getString(0)).mkString("|")

 //search function for pattern search

 val valsum = udf((txt: String, pattern : String)=> { 
    txt.split("\\|").count(_.contains(pattern)) 
  } )

 //apply the UDF on the first col 
 val dfValSum = DF.withColumn("results", valsum( lit(text),DF("case")))  
Run Code Online (Sandbox Code Playgroud)

San*_*ver 0

这个有效

import common.Spark.sparkSession
import java.util.regex.Pattern
import util.control.Breaks._

object playground extends App {

  import org.apache.spark.sql.functions._

  val pattern = "135322,121200" // Pattern you want to search for

  // udf declaration
  val coder: ((String, String) => Boolean) = (caseCol: String, pattern: String) =>
    {
      var result = true
      val splitPattern = pattern.split(",")
      val splitCaseCol = caseCol.split(",")
      var foundAtIndex = -1

      for (i <- 0 to splitPattern.length - 1) {
        breakable {
          for (j <- 0 to splitCaseCol.length - 1) {
            if (j > foundAtIndex) {
              println(splitCaseCol(j))
              if (splitCaseCol(j) == splitPattern(i)) {
                result = true
                foundAtIndex = j
                break
              } else result = false
            } else result = false
          }
        }
      }
      println(caseCol, result)
      (result)
    }

  // registering the udf  
  val udfFilter = udf(coder)

  //reading the input file
  val df = sparkSession.read.option("delimiter", "\t").option("header", "true").csv("output.txt")

  //calling the function and aggregating
  df.filter(udfFilter(col("Case"), lit(pattern))).agg(lit(pattern), sum("Freq")).toDF("pattern","sum").show

}
Run Code Online (Sandbox Code Playgroud)

如果输入是

135322,121200

输出是

+-------------+----+
|      pattern| sum|
+-------------+----+
|135322,121200|13.0|
+-------------+----+
Run Code Online (Sandbox Code Playgroud)

如果输入是

135322,135322

输出是

+-------------+----+
|      pattern| sum|
+-------------+----+
|135322,135322|22.0|
+-------------+----+
Run Code Online (Sandbox Code Playgroud)