Scala Spark UDF to filter an array of structs

al3*_*uch 6 scala apache-spark apache-spark-sql

I have a DataFrame with the following schema:

root
 |-- x: long (nullable = false)
 |-- y: long (nullable = false)
 |-- features: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- score: double (nullable = true)

For example, I have this data:

+--------------------+--------------------+------------------------------------------+
|                x   |              y     |       features                           |
+--------------------+--------------------+------------------------------------------+
|10                  |          9         |[["f1", 5.9], ["ft2", 6.0], ["ft3", 10.9]]|
|11                  |          0         |[["f4", 0.9], ["ft1", 4.0], ["ft2", 0.9] ]|
|20                  |          9         |[["f5", 5.9], ["ft2", 6.4], ["ft3", 1.9] ]|
|18                  |          8         |[["f1", 5.9], ["ft4", 8.1], ["ft2", 18.9]]|
+--------------------+--------------------+------------------------------------------+

I want to keep only the features whose name has a particular prefix, say "ft", so the result I want is:

+--------------------+--------------------+-----------------------------+
|                x   |              y     |       features              |
+--------------------+--------------------+-----------------------------+
|10                  |          9         |[["ft2", 6.0], ["ft3", 10.9]]|
|11                  |          0         |[["ft1", 4.0], ["ft2", 0.9] ]|
|20                  |          9         |[["ft2", 6.4], ["ft3", 1.9] ]|
|18                  |          8         |[["ft4", 8.1], ["ft2", 18.9]]|
+--------------------+--------------------+-----------------------------+

I am not using Spark 2.4+, so I cannot use the solution provided here: Spark (Scala) filter array of structs without explode

I tried using a UDF, but it still does not work. Here are my attempts. I define a UDF:

def filterFeature: UserDefinedFunction = 
udf((features: Seq[Row]) =>
    features.filter{
        x => x.getString(0).startsWith("ft")
    }
)

But if I apply this UDF

df.withColumn("filtered", filterFeature($"features"))

I get the error: Schema for type org.apache.spark.sql.Row is not supported. I found out that I cannot return Row from a UDF. Then I tried:

def filterFeature: UserDefinedFunction = 
udf((features: Seq[Row]) =>
    features.filter{
        x => x.getString(0).startsWith("ft")
    }, (StringType, DoubleType)
)

Then I get an error:

 error: type mismatch;
 found   : (org.apache.spark.sql.types.StringType.type, org.apache.spark.sql.types.DoubleType.type)
 required: org.apache.spark.sql.types.DataType
              }, (StringType, DoubleType)
                 ^

I also tried a case class, as some answers suggest:

case class FilteredFeature(featureName: String, featureScore: Double)
def filterFeature: UserDefinedFunction = 
udf((features: Seq[Row]) =>
    features.filter{
        x => x.getString(0).startsWith("ft")
    }, FilteredFeature
)

But I got:

 error: type mismatch;
 found   : FilteredFeature.type
 required: org.apache.spark.sql.types.DataType
              }, FilteredFeature
                 ^

I tried:

case class FilteredFeature(featureName: String, featureScore: Double)
def filterFeature: UserDefinedFunction = 
udf((features: Seq[Row]) =>
    features.filter{
        x => x.getString(0).startsWith("ft")
    }, Seq[FilteredFeature]
)

I got:

<console>:192: error: missing argument list for method apply in class GenericCompanion
Unapplied methods are only converted to functions when a function type is expected.
You can make this conversion explicit by writing `apply _` or `apply(_)` instead of `apply`.
              }, Seq[FilteredFeature]
                    ^

I tried:

case class FilteredFeature(featureName: String, featureScore: Double)
def filterFeature: UserDefinedFunction = 
udf((features: Seq[Row]) =>
    features.filter{
        x => x.getString(0).startsWith("ft")
    }, Seq[FilteredFeature](_)
)

I got:

<console>:201: error: type mismatch;
 found   : Seq[FilteredFeature]
 required: FilteredFeature
              }, Seq[FilteredFeature](_)
                          ^

What should I do in this situation?

Rap*_*oth 5

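You have two options:

a) Provide a schema to the UDF; this lets you return Seq[Row].

b) Convert the Seq[Row] to a Seq of Tuple2 or of a case class; then you do not need to provide a schema (but the struct field names are lost if you use tuples!).

For your case I would prefer option a) (it also works well for structs with many fields):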

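// Reuse the existing type of the "features" column as the UDF's return type
val schema = df.schema("features").dataType

// Keep only the structs whose "name" field starts with "ft"
val filterFeature = udf((features: Seq[Row]) => features.filter(_.getAs[String]("name").startsWith("ft")), schema)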
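
For completeness, option b) could look roughly like the sketch below. It is only an illustration under some assumptions: the names Feature, FilterFeaturesSketch and keepPrefixed are made up for this example, the struct fields are read by position (name first, score second, as in the schema from the question), and neither the array nor its elements are assumed to be null. Because the case class uses the same field names as the original struct, the resulting column should keep name and score as field names.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.udf

// Hypothetical case class; its field names become the struct field names of the result.
case class Feature(name: String, score: Double)

object FilterFeaturesSketch {
  // Keep the structs whose first field (name) starts with the prefix and
  // convert each one to a Feature. Assumes field order (name, score) and no nulls.
  def keepPrefixed(features: Seq[Row], prefix: String): Seq[Feature] =
    features
      .filter(_.getString(0).startsWith(prefix))
      .map(r => Feature(r.getString(0), r.getDouble(1)))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("filter-features-sketch")
      .getOrCreate()
    import spark.implicits._

    // Two sample rows taken from the question.
    val df = Seq(
      (10L, 9L, Seq(Feature("f1", 5.9), Feature("ft2", 6.0), Feature("ft3", 10.9))),
      (11L, 0L, Seq(Feature("f4", 0.9), Feature("ft1", 4.0), Feature("ft2", 0.9)))
    ).toDF("x", "y", "features")

    // No explicit schema: the return type Seq[Feature] is derived from the case class.
    val filterFeature = udf((features: Seq[Row]) => keepPrefixed(features, "ft"))

    df.withColumn("features", filterFeature($"features")).show(false)
    spark.stop()
  }
}
```

Compared with option a), this avoids passing a schema but requires one case class per struct shape, which is why a) tends to be more convenient for structs with many fields.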