use*_*816 2 scala machine-learning apache-spark
我在Apache Spark ML(版本2.1.0)中使用NaiveBayes多项式分类器来预测某些文本类别。
问题是如何在没有训练有素的DataFrame的情况下将预测标签(0.0、1.0、2.0)转换为字符串。
我知道可以使用IndexToString,但是只有在训练和预测都在同一时间的情况下,它才有用。但是,就我而言,它是独立的工作。
代码如下所示:
1)TrainingModel.scala:训练模型并将模型保存在文件中。
2)CategoryPrediction.scala:从文件中加载训练后的模型并对测试数据进行预测。
请提出解决方案:
TrainingModel.scala
val trainData: Dataset[LabeledRecord] = spark.read.option("inferSchema", "false")
.schema(schema).csv("trainingdata1.csv").as[LabeledRecord]
val labelIndexer = new StringIndexer().setInputCol("category").setOutputCol("label").fit(trainData).setHandleInvalid("skip")
val tokenizer = new RegexTokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("features")
.setNumFeatures(1000)
val rf = new NaiveBayes().setLabelCol("label").setFeaturesCol("features").setModelType("multinomial")
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, labelIndexer, rf))
val model = pipeline.fit(trainData)
model.write.overwrite().save("naivebayesmodel");
Run Code Online (Sandbox Code Playgroud)
CategoryPrediction.scala
val testData: Dataset[PredictLabeledRecord] = spark.read.option("inferSchema", "false")
.schema(predictSchema).csv("testingdata.csv").as[PredictLabeledRecord]
val model = PipelineModel.load("naivebayesmodel")
val predictions = model.transform(testData)
// val labelConverter = new IndexToString()
// .setInputCol("prediction")
// .setOutputCol("predictedLabelString")
// .setLabels(trainDataFrameIndexer.labels)
predictions.select("prediction", "text").show(false)
Run Code Online (Sandbox Code Playgroud)
trainingdata1.csv
category,text
Drama,"a b c d e spark"
Action,"b d"
Horror,"spark f g h"
Thriller,"hadoop mapreduce"
Run Code Online (Sandbox Code Playgroud)
testingdata.csv
text
"a b c d e spark"
"spark f g h"
Run Code Online (Sandbox Code Playgroud)
添加一个转换器,将预测类别转换回管道中的标签,如下所示:
val categoryConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("category")
.setLabels(labelIndexer.labels)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, labelIndexer, rf, categoryConverter))
Run Code Online (Sandbox Code Playgroud)
这将获取预测,并使用labelIndexer将其转换回标签。
| 归档时间: |
|
| 查看次数: |
984 次 |
| 最近记录: |