了解decision_function 值

Foz*_*oro 6 python machine-learning scikit-learn

我目前正处于我的第一次机器学习阶段,到目前为止我还不太清楚我从中获得的价值的规模decision_function(X)(也不知道如何理解它们)。

基于 sklearn文档 decision_function(X)旨在:

预测样本的置信度分数。

尽管如此,在运行以下脚本时:

from sklearn.datasets import fetch_mldata
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix , precision_score, recall_score


mnist = fetch_mldata('MNIST original')

classifier = SGDClassifier(random_state = 42, max_iter = 5)


X,y = mnist["data"], mnist["target"]
some_digit = X[36001]
some_digit_image = some_digit.reshape(28, 28)

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

random_order = np.random.permutation(60000)

X_train, y_train = X_train[random_order], y_train[random_order]

y_test_5 = (y_test == 5)
y_train_5 = (y_train == 5)


classifier.fit(X_train, y_train_5)
print(classifier.decision_function([X_test[1]]))
Run Code Online (Sandbox Code Playgroud)

它打印出来[-289809.39489525]decision_function此时我不知道如何阅读或如何评估这些值(我期待看到百分比)。如果有人能向我解释这些读数的含义,我将不胜感激。

非常感谢您提前。

Jan*_*n K 7

如何获得概率(百分比)?

使用predict_proba方法。

什么是 decision_function

由于SGDClassifier是线性模型,decision_function输出到分离超平面的有符号距离。这个数字只是 < w , x > + b 或转换为 scikit-learn 属性名称 < coef_, x > + intercept_

  • 如需进一步参考,您可以查看:https://datascience.stackexchange.com/questions/18374/predicting-probability-from-scikit-learn-svc-decision-function-with-decision-fun/18375#18375 (2认同)