Question by cix*_*cix (score 9) · Tags: dataframe, python-3.x, apache-spark-sql, pyspark
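I have a DataFrame `df` and have trained a decision tree classifier on it. The two columns used to run the algorithm are `label` and `features`, and the classifier is called `dtc`. How can I create a confusion matrix in PySpark?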
```python
dtc = DecisionTreeClassifier(featuresCol = 'features', labelCol = 'label')
dtcModel = dtc.fit(train)
predictions = dtcModel.transform(test)
```
```python
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.evaluation import MulticlassMetrics
preds = df.select(['label', 'features']) \
.df.map(lambda line: (line[1], line[0]))
metrics = MulticlassMetrics(preds)
# Confusion Matrix
print(metrics.confusionMatrix().toArray())
```
Answer by Qua*_*ats (score 13)
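`MulticlassMetrics` expects an RDD of (prediction, label) tuples rather than a DataFrame, so you need to convert to an RDD and map each row to a tuple before calling `metrics.confusionMatrix().toArray()`.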
From the official documentation:

> class pyspark.mllib.evaluation.MulticlassMetrics(predictionAndLabels)
>
> Evaluator for multiclass classification.
>
> Parameters: predictionAndLabels: an RDD of (prediction, label) pairs.
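To make that expected input concrete, here is a small pure-Python sketch of the conversion that `.rdd.map(tuple)` plus the float cast perform in the example below. It is only an illustration under stated assumptions: a `namedtuple` stands in for `pyspark.sql.Row` (which is also a tuple subclass) so the snippet runs without Spark, and `to_prediction_label_pairs` is a hypothetical helper, not a PySpark function.

```python
from collections import namedtuple

# Hypothetical stand-in for pyspark.sql.Row, used only so this runs without Spark.
Row = namedtuple("Row", ["prediction", "label"])


def to_prediction_label_pairs(rows):
    """Turn rows into (prediction, label) float pairs, the shape MulticlassMetrics expects."""
    pairs = []
    for row in rows:
        # tuple(row) keeps the column order of the select, so prediction must come first
        prediction, label = tuple(row)
        # the integer label is converted to float, like the cast in the answer's example
        pairs.append((float(prediction), float(label)))
    return pairs


rows = [Row(prediction=0.0, label=0), Row(prediction=2.0, label=1)]
print(to_prediction_label_pairs(rows))  # [(0.0, 0.0), (2.0, 1.0)]
```

Note that the question passes `(features, label)` pairs; the first element has to be the model's `prediction` column instead.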
Here is an example to guide you.

Machine learning part:
```python
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql.types import FloatType
# Note the differences between ml and mllib: they are two different libraries.

spark = SparkSession.builder.getOrCreate()

# create a sample data frame
data = [(1.54, 3.45, 2.56, 0), (9.39, 8.31, 1.34, 0), (1.25, 3.31, 9.87, 1), (9.35, 5.67, 2.49, 2),
        (1.23, 4.67, 8.91, 1), (3.56, 9.08, 7.45, 2), (6.43, 2.23, 1.19, 1), (7.89, 5.32, 9.08, 2)]

cols = ('a', 'b', 'c', 'd')

df = spark.createDataFrame(data, cols)

# combine the raw columns into a single vector column named 'features'
assembler = VectorAssembler(inputCols=['a', 'b', 'c'], outputCol='features')

df_features = assembler.transform(df)

#df.show()

train_data, test_data = df_features.randomSplit([0.6, 0.4])

dtc = DecisionTreeClassifier(featuresCol='features', labelCol='d')

dtcModel = dtc.fit(train_data)

# adds a 'prediction' column (among others) to the test data
predictions = dtcModel.transform(test_data)
```

Evaluation part:
```python
# important: need to cast to float type, and order by prediction, else it won't work
preds_and_labels = predictions.select(['prediction', 'd']).withColumn('label', F.col('d').cast(FloatType())).orderBy('prediction')

# select only prediction and label columns
preds_and_labels = preds_and_labels.select(['prediction', 'label'])

# MulticlassMetrics needs an RDD of (prediction, label) tuples, not a DataFrame
metrics = MulticlassMetrics(preds_and_labels.rdd.map(tuple))

print(metrics.confusionMatrix().toArray())
```
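For intuition about what `toArray()` prints, here is a plain-Python sketch that tallies the same kind of matrix from a list of (prediction, label) pairs. It assumes the layout the `MulticlassMetrics` documentation describes (predicted classes in columns, classes in ascending label order) and, as a further assumption, takes the class list from the actual labels only. `confusion_matrix` is a hypothetical helper for illustration, not part of PySpark.

```python
from collections import Counter


def confusion_matrix(pairs):
    """Tally (prediction, label) pairs into a confusion matrix.

    Rows are actual labels and columns are predicted labels, both sorted
    ascending. `pairs` must be a list (it is iterated twice).
    """
    # count how often each (actual, predicted) combination occurs
    counts = Counter((label, prediction) for prediction, label in pairs)
    # assumption: the class list comes from the actual labels
    labels = sorted({label for _, label in pairs})
    # predictions outside `labels` have no column and are therefore not counted
    matrix = [[float(counts[(actual, predicted)]) for predicted in labels]
              for actual in labels]
    return labels, matrix


pairs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (2.0, 2.0), (1.0, 2.0)]
labels, matrix = confusion_matrix(pairs)
print(labels)  # [0.0, 1.0, 2.0]
for row in matrix:
    print(row)  # [1.0, 1.0, 0.0], then [0.0, 1.0, 0.0], then [0.0, 1.0, 1.0]
```

The diagonal holds the correctly classified counts, so dividing its sum by the total number of pairs gives the overall accuracy.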
Views: 18604