Vid*_*the 7 python classification svm libsvm
我写了这段代码,想获得分类的概率。
from sklearn import svm
X = [[0, 0], [10, 10],[20,30],[30,30],[40, 30], [80,60], [80,50]]
y = [0, 1, 2, 3, 4, 5, 6]
clf = svm.SVC()
clf.probability=True
clf.fit(X, y)
prob = clf.predict_proba([[10, 10]])
print prob
Run Code Online (Sandbox Code Playgroud)
我得到了这个输出:
[[0.15376986 0.07691205 0.15388546 0.15389275 0.15386348 0.15383004 0.15384636]]
Run Code Online (Sandbox Code Playgroud)
这很奇怪,因为概率应该是
[0 1 0 0 0 0 0 0]
Run Code Online (Sandbox Code Playgroud)
(注意必须预测类别的样本与第二个样本相同)同样,该类别获得的概率最低。
您应该禁用probability并使用decision_function,因为不能保证predict_proba并predict返回相同的结果。您可以在文档中阅读有关它的更多信息。
clf.predict([[10, 10]]) // returns 1 as expected
prop = clf.decision_function([[10, 10]]) // returns [[ 4.91666667 6.5 3.91666667 2.91666667 1.91666667 0.91666667
-0.08333333]]
prediction = np.argmax(prop) // returns 1
Run Code Online (Sandbox Code Playgroud)
编辑:正如@TimH 所指出的,概率可以由clf.decision_function(X). 下面的代码是固定的。注意到指定的使用概率低的问题predict_proba(X),我认为答案是根据官方文档here,...。此外,它会在非常小的数据集上产生毫无意义的结果。
答案是理解 SVM 的结果概率是多少。简而言之,您在 2D 平面中有 7 个类和 7 个点。SVM 试图做的是在每个类之间找到一个线性分隔符(一对一方法)。每次只选择 2 个班级。你得到的是归一化后分类器的投票。在这篇文章或这里(scikit-learn 使用 libsvm)查看更多关于libsvm 的多类 SVM 的详细解释。
通过稍微修改您的代码,我们看到确实选择了正确的类:
from sklearn import svm
import matplotlib.pyplot as plt
import numpy as np
X = [[0, 0], [10, 10],[20,30],[30,30],[40, 30], [80,60], [80,50]]
y = [0, 1, 2, 3, 3, 4, 4]
clf = svm.SVC()
clf.fit(X, y)
x_pred = [[10,10]]
p = np.array(clf.decision_function(x_pred)) # decision is a voting function
prob = np.exp(p)/np.sum(np.exp(p),axis=1, keepdims=True) # softmax after the voting
classes = clf.predict(x_pred)
_ = [print('Sample={}, Prediction={},\n Votes={} \nP={}, '.format(idx,c,v, s)) for idx, (v,s,c) in enumerate(zip(p,prob,classes))]
Run Code Online (Sandbox Code Playgroud)
对应的输出是
Sample=0, Prediction=0,
Votes=[ 6.5 4.91666667 3.91666667 2.91666667 1.91666667 0.91666667 -0.08333333]
P=[ 0.75531071 0.15505748 0.05704246 0.02098475 0.00771986 0.00283998 0.00104477],
Sample=1, Prediction=1,
Votes=[ 4.91666667 6.5 3.91666667 2.91666667 1.91666667 0.91666667 -0.08333333]
P=[ 0.15505748 0.75531071 0.05704246 0.02098475 0.00771986 0.00283998 0.00104477],
Sample=2, Prediction=2,
Votes=[ 1.91666667 2.91666667 6.5 4.91666667 3.91666667 0.91666667 -0.08333333]
P=[ 0.00771986 0.02098475 0.75531071 0.15505748 0.05704246 0.00283998 0.00104477],
Sample=3, Prediction=3,
Votes=[ 1.91666667 2.91666667 4.91666667 6.5 3.91666667 0.91666667 -0.08333333]
P=[ 0.00771986 0.02098475 0.15505748 0.75531071 0.05704246 0.00283998 0.00104477],
Sample=4, Prediction=4,
Votes=[ 1.91666667 2.91666667 3.91666667 4.91666667 6.5 0.91666667 -0.08333333]
P=[ 0.00771986 0.02098475 0.05704246 0.15505748 0.75531071 0.00283998 0.00104477],
Sample=5, Prediction=5,
Votes=[ 3.91666667 2.91666667 1.91666667 0.91666667 -0.08333333 6.5 4.91666667]
P=[ 0.05704246 0.02098475 0.00771986 0.00283998 0.00104477 0.75531071 0.15505748],
Sample=6, Prediction=6,
Votes=[ 3.91666667 2.91666667 1.91666667 0.91666667 -0.08333333 4.91666667 6.5 ]
P=[ 0.05704246 0.02098475 0.00771986 0.00283998 0.00104477 0.15505748 0.75531071],
Run Code Online (Sandbox Code Playgroud)
您还可以看到决策区:
X = np.array(X)
y = np.array(y)
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)
XX, YY = np.mgrid[0:100:200j, 0:100:200j]
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()])
Z = Z.reshape(XX.shape)
plt.figure(1, figsize=(4, 3))
plt.pcolormesh(XX, YY, Z, cmap=plt.cm.Paired)
for idx in range(7):
ax.scatter(X[idx,0],X[idx,1], color='k')
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
11234 次 |
| 最近记录: |