Foz*_*oro 6 python machine-learning scikit-learn
我目前正处于我的第一次机器学习阶段,到目前为止我还不太清楚我从中获得的价值的规模decision_function(X)(也不知道如何理解它们)。
基于 sklearn文档 decision_function(X)旨在:
预测样本的置信度分数。
尽管如此,在运行以下脚本时:
from sklearn.datasets import fetch_mldata
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix , precision_score, recall_score
mnist = fetch_mldata('MNIST original')
classifier = SGDClassifier(random_state = 42, max_iter = 5)
X,y = mnist["data"], mnist["target"]
some_digit = X[36001]
some_digit_image = some_digit.reshape(28, 28)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
random_order = np.random.permutation(60000)
X_train, y_train = X_train[random_order], y_train[random_order]
y_test_5 = (y_test == 5)
y_train_5 = (y_train == 5)
classifier.fit(X_train, y_train_5)
print(classifier.decision_function([X_test[1]]))
Run Code Online (Sandbox Code Playgroud)
它打印出来[-289809.39489525],decision_function此时我不知道如何阅读或如何评估这些值(我期待看到百分比)。如果有人能向我解释这些读数的含义,我将不胜感激。
非常感谢您提前。
如何获得概率(百分比)?
使用predict_proba方法。
什么是 decision_function ?
由于SGDClassifier是线性模型,decision_function输出到分离超平面的有符号距离。这个数字只是 < w , x > + b 或转换为 scikit-learn 属性名称 < coef_, x > + intercept_。
| 归档时间: |
|
| 查看次数: |
3797 次 |
| 最近记录: |