Des*_* pv 6 machine-learning apache-spark apache-spark-mllib
我试图在图中绘制ROC曲线和Precision-Recall曲线.这些点是从Spark Mllib BinaryClassificationMetrics生成的.按照以下Spark https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html
[(1.0,1.0), (0.0,0.4444444444444444)] Precision
[(1.0,1.0), (0.0,1.0)] Recall
[(1.0,1.0), (0.0,0.6153846153846153)] - F1Measure
[(0.0,1.0), (1.0,1.0), (1.0,0.4444444444444444)]- Precision-Recall curve
[(0.0,0.0), (0.0,1.0), (1.0,1.0), (1.0,1.0)] - ROC curve
Run Code Online (Sandbox Code Playgroud)
看起来你和我的经历有类似的问题.您需要将参数翻转到Metrics构造函数,或者传递概率而不是预测.因此,例如,如果您使用的是BinaryClassificationMetrics和a RandomForestClassifier,则根据此页面(在输出下)有"预测"和"概率".
然后初始化您的度量标准:
new BinaryClassificationMetrics(predictionsWithResponse
.select(col("probability"),col("myLabel"))
.rdd.map(r=>(r.getAs[DenseVector](0)(1),r.getDouble(1))))
Run Code Online (Sandbox Code Playgroud)
用DenseVector调用来提取1类的概率.
至于实际绘图,这取决于你(很多很好的工具),但至少你会在曲线上获得超过1点(除了端点).
如果不清楚:
metrics.roc().collect() 将为您提供ROC曲线的数据:元组:(误报率,真阳性率).
| 归档时间: |
|
| 查看次数: |
7279 次 |
| 最近记录: |