Using Scala, convert all columns of one DataType in a Spark DataFrame to another DataType

vam*_*msi 1 scala dataframe apache-spark

I have a Spark DataFrame with more than 100 columns. In this DataFrame, I want to convert all DoubleType columns to DecimalType(18,5). I can convert one specific data type to another specific data type as follows:

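import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

def castAllTypedColumnsTo(inputDF: DataFrame, sourceType: DataType): DataFrame = {
  // Map the source type to its target type; any other type is left unchanged
  val targetType = sourceType match {
    case DoubleType => DecimalType(18, 5)
    case _          => sourceType
  }

  // Cast every field of sourceType in place, one withColumn per matching field
  inputDF.schema.filter(_.dataType == sourceType).foldLeft(inputDF) {
    case (acc, field) => acc.withColumn(field.name, col(field.name).cast(targetType))
  }
}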

val inputDF = Seq((1,1.0),(2,2.0)).toDF("id","amount")

inputDF.printSchema()

root
 |-- id: integer (nullable = true)
 |-- amount: double (nullable = true)

val finalDF : DataFrame = castAllTypedColumnsTo(inputDF, DoubleType)

finalDF.printSchema()

root
 |-- id: integer (nullable = true)
 |-- amount: decimal(18,5) (nullable = true)


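Here I filter the DoubleType columns and cast them to DecimalType(18,5). Suppose I also want to convert another data type: how can I handle that without passing the data type as an input parameter?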

I'm looking for something like the following:

def convertDataType(inputDF: DataFrame): DataFrame = {
  inputDF.dtypes.map {
    case (colName, colType) => (colName, colType match {
      case "DoubleType" => DecimalType(18, 5).toString
      case _            => colType
    })
  }
  // finalDF to be created with new DataType.
}

val finalDF = convertDataType(inputDF)
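A minimal sketch that completes the dtypes-based idea above, written for spark-shell (paste mode). It assumes, as the question does, that dtypes reports a double column as the string "DoubleType" (its DataType.toString); the local SparkSession settings are illustrative only. Rather than building a new schema, it turns each (name, type) pair into a Column and applies all of them in one select.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DecimalType

// Build one Column per field from dtypes, then apply them in a single select.
def convertDataType(inputDF: DataFrame): DataFrame = {
  val columns: Array[Column] = inputDF.dtypes.map {
    // dtypes reports each type via DataType.toString, e.g. "DoubleType"
    case (colName, "DoubleType") => col(colName).cast(DecimalType(18, 5))
    // every other column is passed through unchanged
    case (colName, _)            => col(colName)
  }
  inputDF.select(columns: _*)
}

// Usage (spark-shell already provides a session; getOrCreate reuses it)
val spark = SparkSession.builder().master("local[1]").appName("cast-sketch").getOrCreate()
import spark.implicits._

val inputDF = Seq((1, 1.0), (2, 2.0)).toDF("id", "amount")
val finalDF = convertDataType(inputDF)
finalDF.printSchema()
```

Matching on type strings keeps the shape of the question's code, but it depends on the exact toString text of each DataType.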

Can someone help me with this scenario?

Sri*_*vas 5

Try the code below.

scala> :paste
// Entering paste mode (ctrl-D to finish)

import org.apache.spark.sql.types.StructField

def castAllTypedColumnsTo(field: StructField) = field.dataType.typeName match {
      case "double" => col(field.name).cast("decimal(18,5)")
      case "integer" => col(field.name).cast("integer")
      case _ => col(field.name)
}
inputDF
.select(inputDF.schema.map(castAllTypedColumnsTo):_*)
.show(false)

// Exiting paste mode, now interpreting.

+---+-------+
|id |amount |
+---+-------+
|1  |1.00000|
|2  |2.00000|
+---+-------+

import org.apache.spark.sql.types.StructField
castAllTypedColumnsTo: (field: org.apache.spark.sql.types.StructField)org.apache.spark.sql.Column

scala>
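The answer above matches on typeName strings and includes a no-op branch for integer columns. A minimal sketch of a variation, assuming Spark 2.x or 3.x in spark-shell: keep the source-to-target pairs in a Map[DataType, DataType] and look up each field's type, so adding another conversion means adding one map entry. The names defaultConversions and convertDataTypes are hypothetical. All casts are built into a single select; the Spark API docs for withColumn note that calling it repeatedly in a loop can produce large query plans, which matters for a DataFrame with 100+ columns.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

// Hypothetical conversion table: source type -> target type.
// Types not listed here are left unchanged.
val defaultConversions: Map[DataType, DataType] = Map(
  DoubleType -> DecimalType(18, 5),
  FloatType  -> DecimalType(18, 5)
)

// Cast every column whose type appears in `conversions` in one select,
// keeping the original column order and names.
def convertDataTypes(inputDF: DataFrame,
                     conversions: Map[DataType, DataType] = defaultConversions): DataFrame = {
  val columns = inputDF.schema.fields.map { field =>
    conversions.get(field.dataType) match {
      case Some(target) => col(field.name).cast(target).as(field.name)
      case None         => col(field.name)
    }
  }
  inputDF.select(columns: _*)
}

// Usage (spark-shell already provides a session; getOrCreate reuses it)
val spark = SparkSession.builder().master("local[1]").appName("cast-map-sketch").getOrCreate()
import spark.implicits._

val inputDF = Seq((1, 1.0f, 1.0), (2, 2.0f, 2.0)).toDF("id", "ratio", "amount")
val finalDF = convertDataTypes(inputDF)   // no type argument needed
finalDF.printSchema()
```

Keying the map on DataType values rather than strings lets the compiler check the type names, and the default argument means callers never pass a type unless they want a different table.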