从字符串文字中推断Spark DataType

sme*_*eeb 8 types scala introspection apache-spark spark-dataframe

我正在尝试编写一个Scala函数,它可以根据提供的输入字符串推断Spark DataTypes:

/**
 * Example:
 * ========
 * toSparkType("string")  =>    StringType
 * toSparkType("boolean") =>    BooleanType
 * toSparkType("date")    =>    DateType
 * etc.
 */
def toSparkType(inputType : String) : DataType = {
    var dt : DataType = null

    if(matchesStringRegex(inputType)) {
        dt = StringType
    } else if(matchesBooleanRegex(inputType)) {
        dt = BooleanType
    } else if(matchesDateRegex(inputType)) {
        dt = DateType
    } else if(...) {
        ...
    }

    dt
}
Run Code Online (Sandbox Code Playgroud)

我的目标是支持可用的大部分(如果不是全部的话)DataTypes.当我开始实现这个功能,我开始思考:" 星火/斯卡拉可能已经有一个助手/ util的方法,会为我做这件事. "毕竟,我知道我可以这样做:

var structType = new StructType()

structType.add("some_new_string_col", "string", true, Metadata.empty)
structType.add("some_new_boolean_col", "boolean", true, Metadata.empty)
structType.add("some_new_date_col", "date", true, Metadata.empty)
Run Code Online (Sandbox Code Playgroud)

并且Scala和/或Spark都会隐式地将我的"string"参数转换为StringType等等.所以我问:我可以使用Spark或Scala来帮助我实现转换器方法有什么魔力?

Sac*_*agi 15

Spark/Scala可能已经有了一个helper/util方法,可以为我做这个.

你是对的.Spark已经拥有自己的架构和数据类型推断代码,它用于从底层数据源(csv,json等)推断架构.所以你可以看看它实现自己的(实际的实现被标记为Spark私有,是绑定到RDD和内部类,所以它不能直接从Spark之外的代码中使用,但应该让你知道如何去做它.)

鉴于csv是平面类型(并且json可以具有嵌套结构),csv模式推断相对更直接,并且可以帮助您完成上面尝试实现的任务.所以我将解释csv推理是如何工作的(json推理只需要考虑可能的嵌套结构,但数据类型推断非常类似).

有了这个序幕,你要看的东西是CSVInferSchema对象.特别是,查看在整个RDD中infer获取RDD[Array[String]]和推断数组的每个元素的数据类型的方法.它的作用方式是 - 它将每个字段标记为开始,然后当它迭代下一行值()时,如果新的更具体,则将已推断为新的更新.这发生在这里:NullTypeArray[String]RDDDataTypeDataTypeDataType

val rootTypes: Array[DataType] =
      tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
Run Code Online (Sandbox Code Playgroud)

现在inferRowType 调用行 inferField中的每个字段.inferField 实现是你可能正在寻找的 - 它为特定字段推断到目前为止的类型,并将当前行的字段的字符串值作为参数.然后,它返回现有的推断类型,或者如果推断的新类型更具体,则返回新类型.

代码的相关部分如下:

typeSoFar match {
        case NullType => tryParseInteger(field, options)
        case IntegerType => tryParseInteger(field, options)
        case LongType => tryParseLong(field, options)
        case _: DecimalType => tryParseDecimal(field, options)
        case DoubleType => tryParseDouble(field, options)
        case TimestampType => tryParseTimestamp(field, options)
        case BooleanType => tryParseBoolean(field, options)
        case StringType => StringType
        case other: DataType =>
          throw new UnsupportedOperationException(s"Unexpected data type $other")
      }
Run Code Online (Sandbox Code Playgroud)

请注意,如果typeSoFar是NullType,那么它首先尝试将其解析为Integer但是tryParseIntegercall是对较低类型解析的调用链.因此,如果它无法将值解析为Integer,那么它将调用tryParseLong失败时将调用tryParseDecimal哪个失败将最终调用tryParseDoublewofwi tryParseTimestampwofwi tryParseBooleanwofwi stringType.

所以你可以使用几乎相似的逻辑来实现你的用例.(如果您不需要跨行合并,那么您只需tryParse*逐字实现所有方法并简单地调用tryParseInteger.无需编写自己的正则表达式.)

希望这可以帮助.


aln*_*lno 9

是的,当然Spark有你需要的魔力.

在Spark 2.x中它的CatalystSqlParser对象,在这里定义.

例如:

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

CatalystSqlParser.parseDataType("string") // StringType
CatalystSqlParser.parseDataType("int") // IntegerType
Run Code Online (Sandbox Code Playgroud)

等等.

但据我了解,它不是公共API的一部分,因此可能会在下一版本中发生变化,而不会发出任何警告.

所以你可以将你的方法实现为:

def toSparkType(inputType: String): DataType = CatalystSqlParser.parseDataType(inputType)
Run Code Online (Sandbox Code Playgroud)