How to convert an RDD with a SparseVector column into a DataFrame whose column is a Vector

Ora*_*uez · tags: apache-spark apache-spark-sql pyspark apache-spark-ml apache-spark-mllib

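I have an RDD of (String, SparseVector) tuples and I want to create a DataFrame from it, so that I get a (label: string, features: vector) DataFrame, which is the schema most of the ml algorithm library requires. I know this is possible, because the HashingTF transformer from the ml library outputs a vector column (features) when it is applied to a DataFrame: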

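from pyspark.ml.feature import HashingTF
from pyspark.sql.types import (StructType, StructField, DoubleType,
                               ArrayType, StringType)

# assuming temp_rdd is an RDD of (double, array<string>) tuples
temp_df = sqlContext.createDataFrame(temp_rdd, StructType([
        StructField("label", DoubleType(), False),
        StructField("tokens", ArrayType(StringType()), False)
    ]))

# COMBINATIONS is the desired number of hash buckets (defined elsewhere)
hashingTF = HashingTF(numFeatures=COMBINATIONS, inputCol="tokens", outputCol="features")

ndf = hashingTF.transform(temp_df)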
ndf.printSchema()

#outputs 
#root
#|-- label: double (nullable = false)
#|-- tokens: array (nullable = false)
#|    |-- element: string (containsNull = true)
#|-- features: vector (nullable = true)

So my question is: can I somehow convert an RDD of (String, SparseVector) into a DataFrame of (String, vector)? I tried the usual sqlContext.createDataFrame, but none of the available DataTypes fits my needs:

df = sqlContext.createDataFrame(rdd,StructType([
        StructField("label" , StringType(),True),
        StructField("features" , ?Type(),True)
    ]))

Answer by zer*_*323:

You have to use VectorUDT here:

from pyspark.sql.types import StructType, StructField, DoubleType
# In Spark 1.x
# from pyspark.mllib.linalg import SparseVector, VectorUDT
# In Spark 2.0+
from pyspark.ml.linalg import SparseVector, VectorUDT

temp_rdd = sc.parallelize([
    (0.0, SparseVector(4, {1: 1.0, 3: 5.5})),
    (1.0, SparseVector(4, {0: -1.0, 2: 0.5}))])

schema = StructType([
    StructField("label", DoubleType(), True),
    StructField("features", VectorUDT(), True)
])

temp_rdd.toDF(schema).printSchema()

## root
##  |-- label: double (nullable = true)
##  |-- features: vector (nullable = true)

Just for completeness, the Scala equivalent:

import org.apache.spark.sql.Row
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DoubleType, StructType}
// In Spark 1.x
// import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType

val schema = new StructType()
  .add("label", DoubleType)
  // In Spark 1.x
  // .add("features", new VectorUDT())
  .add("features", VectorType)

val temp_rdd: RDD[Row]  = sc.parallelize(Seq(
  Row(0.0, Vectors.sparse(4, Seq((1, 1.0), (3, 5.5)))),
  Row(1.0, Vectors.sparse(4, Seq((0, -1.0), (2, 0.5))))
))

spark.createDataFrame(temp_rdd, schema).printSchema

// root
// |-- label: double (nullable = true)
// |-- features: vector (nullable = true)

  • Wow, I've been searching for this for ages! Almost a cry of happiness :,) +1