Convert a vector column in a DataFrame back into an array column

ulr*_*ich 6 apache-spark apache-spark-mllib

I have a DataFrame with two columns, one of which (called dist) is a dense vector. How can I convert it back into a column of integer arrays?

+---+-----+
| id| dist| 
+---+-----+
|1.0|[2.0]|
|2.0|[4.0]|
|3.0|[6.0]|
|4.0|[8.0]|
+---+-----+

I tried several variations of the following UDF, but it returns a type mismatch error:

val toInt4 = udf[Int, Vector]({ (a) => (a)})  

val result = df.withColumn("dist", toDf4(df("dist"))).select("dist")

pwb*_*103 10

I struggled for a while to get @ThomasLuechtefeld's answer working, but kept running into this very frustrating error:

org.apache.spark.sql.AnalysisException: cannot resolve 'UDF(features_scaled)' due to data type mismatch: argument 1 requires vector type, however, '`features_scaled`' is of vector type.

It turned out that I needed to import DenseVector from the ml package rather than the mllib package.

So this worked for me:

import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.functions._

val vectorToColumn = udf{ (x:DenseVector, index: Int) => x(index) }
myDataframe.withColumn("clusters_scaled",vectorToColumn(col("features_scaled"),lit(0)))

Yes, the only difference is the first line. This really should be a comment, but I don't have the reputation. Sorry!

  • You could also consider `val vectorToColumn = udf { (x: org.apache.spark.ml.linalg.Vector, index: Int) => x(index) }`, since it can handle both dense and sparse vectors. (2 upvotes)
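
Building on the comment above, here is a minimal sketch that returns the whole vector as an integer array instead of a single element, which is what the question asks for. It assumes Spark 2.x or later with the ml (not mllib) vector types; the local session, the sample data, and the name `vectorToIntArray` are made up for illustration and are not part of the answer.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

// Local session for illustration only; in spark-shell `spark` already exists.
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("vector-to-array")
  .getOrCreate()

// Hypothetical sample data mirroring the question's `id` and `dist` columns.
val df = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(2.0)),
  (2.0, Vectors.dense(4.0))
)).toDF("id", "dist")

// Taking the general ml Vector type lets the UDF accept dense and sparse vectors.
// toArray yields Array[Double]; each element is then truncated to Int.
val vectorToIntArray = udf { (v: Vector) => v.toArray.map(_.toInt) }

// Replace the vector column with an array<int> column.
val result = df.withColumn("dist", vectorToIntArray(col("dist")))
result.printSchema()
```

If an array of doubles is enough, Spark 3.0 and later also ship `org.apache.spark.ml.functions.vector_to_array`, which avoids writing a UDF at all.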

Dan*_*bos 5

I think the simplest way is to drop down to the RDD API and then come back.

import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.sql.DataFrame
import org.apache.spark.rdd.RDD
import sqlContext.implicits._  // brings toDF into scope

// The original data.
val input: DataFrame =
  sc.parallelize(1 to 4)
    .map(i => i.toDouble -> new DenseVector(Array(i.toDouble * 2)))
    .toDF("id", "dist")

// Turn it into an RDD for manipulation.
val inputRDD: RDD[(Double, DenseVector)] =
  input.rdd.map(row => row.getAs[Double]("id") -> row.getAs[DenseVector]("dist"))

// Change the DenseVector into an integer array.
val outputRDD: RDD[(Double, Array[Int])] =
  inputRDD.mapValues(_.toArray.map(_.toInt))

// Go back to a DataFrame.
val output = outputRDD.toDF("id", "dist")
output.show

You get:

+---+----+
| id|dist|
+---+----+
|1.0| [2]|
|2.0| [4]|
|3.0| [6]|
|4.0| [8]|
+---+----+