如何使用Scala运行带有分类功能集的Spark决策树?

Cli*_*der 10 tree scala categorical-data apache-spark apache-spark-mllib

我有一个功能集与相应的categoricalFeaturesInfo:Map [Int,Int].然而,对于我的生活,我无法弄清楚我应该如何使DecisionTree类工作.它不会接受任何内容,而是LabeledPoint作为数据.但是,LabeledPoint需要(double,vector),其中向量需要双精度数.

val LP = featureSet.map(x => LabeledPoint(classMap(x(0)),Vectors.dense(x.tail)))

// Run training algorithm to build the model
val maxDepth: Int = 3
val isMulticlassWithCategoricalFeatures: Boolean = true
val numClassesForClassification: Int = countPossibilities(labelCol) 
val model = DecisionTree.train(LP, Classification, Gini, isMulticlassWithCategoricalFeatures, maxDepth, numClassesForClassification,categoricalFeaturesInfo)
Run Code Online (Sandbox Code Playgroud)

我得到的错误:

scala> val LP = featureSet.map(x => LabeledPoint(classMap(x(0)),Vectors.dense(x.tail)))
<console>:32: error: overloaded method value dense with alternatives:
  (values: Array[Double])org.apache.spark.mllib.linalg.Vector <and>
  (firstValue: Double,otherValues: Double*)org.apache.spark.mllib.linalg.Vector
 cannot be applied to (Array[String])
       val LP = featureSet.map(x => LabeledPoint(classMap(x(0)),Vectors.dense(x.tail)))
Run Code Online (Sandbox Code Playgroud)

到目前为止我的资源: 树配置, 决策树, 标记点

小智 21

您可以先将类别转换为数字,然后加载数据,就像所有要素都是数字一样.

当您在Spark中构建决策树模型时,您只需要通过指定Map[Int, Int]()从特征索引到其arity 的映射来告诉spark哪些要素是分类的,以及要素的arity(该要素的不同类别的数量).

例如,如果您有以下数据:

1,a,add
2,b,more
1,c,thinking
3,a,to
1,c,me
Run Code Online (Sandbox Code Playgroud)

您可以先将数据转换为数字格式,如下所示:

1,0,0
2,1,1
1,2,2
3,0,3
1,2,4
Run Code Online (Sandbox Code Playgroud)

在该格式中,您可以将数据加载到Spark.然后,如果你想告诉Spark第二列和第三列是分类的,你应该创建一个地图:

categoricalFeaturesInfo = Map[Int, Int]((1,3),(2,5))
Run Code Online (Sandbox Code Playgroud)

地图告诉我们索引1的特征具有arity 3,而索引2的特征具有artity 5.当我们构建决策树模型并将该地图作为训练函数的参数传递时,它们将被视为分类:

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
Run Code Online (Sandbox Code Playgroud)

  • 关于这一点的一个尴尬和看似脆弱的事实是,您必须将类别编号存储在LabeledPoint中作为Double类型. (4认同)

小智 0

您需要确认数组 x 的类型。从错误日志来看,数组x中的项目是spark不支持的字符串。当前的 Spark Vector 只能用 Double 填充。