Spark DataFrame中向量的访问元素(Logistic回归概率向量)

use*_*916 19 python apache-spark pyspark spark-dataframe apache-spark-ml

我在PySpark(ML包)中训练了LogisticRegression模型,预测结果是PySpark DataFrame(cv_predictions)(参见[1]).该probability列(见[2])是一种vector类型(见[3]).

[1]
type(cv_predictions_prod)
pyspark.sql.dataframe.DataFrame

[2]
cv_predictions_prod.select('probability').show(10, False)
+----------------------------------------+
|probability                             |
+----------------------------------------+
|[0.31559134817066054,0.6844086518293395]|
|[0.8937864350711228,0.10621356492887715]|
|[0.8615878905395029,0.1384121094604972] |
|[0.9594427633777901,0.04055723662220989]|
|[0.5391547673698157,0.46084523263018434]|
|[0.2820729747752462,0.7179270252247538] |
|[0.7730465873083118,0.22695341269168817]|
|[0.6346585276598942,0.3653414723401058] |
|[0.6346585276598942,0.3653414723401058] |
|[0.637279255218404,0.362720744781596]   |
+----------------------------------------+
only showing top 10 rows

[3]
cv_predictions_prod.printSchema()
root
 ...
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = true)
Run Code Online (Sandbox Code Playgroud)

如何创建解析vectorPySpark DataFrame,以便创建一个新列,只拉取每个probability向量的第一个元素?

这个问题类似于,但下面链接中的解决方案不起作用/我不清楚:

如何在PySpark中访问denseVector的值

如何访问Spark DataFrame中VectorUDT列的元素?

Dav*_*yne 29

更新:

似乎spark中存在一个错误,阻止您在select语句中访问密集向量中的各个元素.通常你应该能像访问一个numpy数组一样访问它们,但是当你试图运行以前发布的代码时,你可能会得到错误pyspark.sql.utils.AnalysisException: "Can't extract value from probability#12;"

所以,处理这个以避免这个愚蠢错误的一种方法是使用udf.与其他问题类似,您可以通过以下方式定义udf:

from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

firstelement=udf(lambda v:float(v[0]),FloatType())
cv_predictions_prod.select(firstelement('probability')).show()
Run Code Online (Sandbox Code Playgroud)

在幕后,这仍然像一个numpy数组一样访问DenseVector的元素,但它不会像以前那样抛出相同的bug.


原始答案:密集向量只是一个numpy数组的包装器.因此,您可以像访问numpy数组的元素一样访问元素.

有几种方法可以访问数据框中数组的各个元素.一种是cv_predictions_prod['probability']在select语句中显式调用该列.通过显式调用该列,您可以对该列执行操作,例如选择数组中的第一个元素.例如:

cv_predictions_prod.select(cv_predictions_prod['probability'][0]).show()
Run Code Online (Sandbox Code Playgroud)

应该解决问题.

  • 不,这行不通。“ VectorUDT”未表示为“ ArrayType”。 (2认同)