__init__() got an unexpected keyword argument 'n_splits' error

She*_*ly 0 python scikit-learn cross-validation

I am trying to run the code from this link:

I get the error at the line StratifiedKFold(n_splits=6). Can anyone tell me how to fix it?

Here is the code:

import numpy as np
from scipy import interp
import matplotlib.pyplot as plt
from itertools import cycle

from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.cross_validation import StratifiedKFold

iris = datasets.load_iris()
X = iris.data
y = iris.target
X, y = X[y != 2], y[y != 2]  # keep classes 0 and 1 in both X and y
X, y

cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel='linear', probability=True,
                     random_state=random_state)

mean_tpr = 0.0
mean_fpr = np.linspace(0, 1, 100)

Here is the error:

TypeError                                 Traceback (most recent call last)
<ipython-input-227-2af2773f4987> in <module>()
----> 1 sklearn.cross_validation.StratifiedKFold(n_splits=6)
      2 #cv = StratifiedKFold(n_splits=6,  shuffle=True, random_state=1)
      3 classifier = svm.SVC(kernel='linear', probability=True,
      4                      random_state=random_state)
      5 

TypeError: __init__() got an unexpected keyword argument 'n_splits'

Viv*_*mar 5

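You did not get any warning when importing the sklearn.cross_validation module. That means your installed version is older than 0.18 (from 0.18 on, importing that module emits a deprecation warning).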
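
Rather than inferring the version from a missing warning, it can be read directly. The following is a minimal sketch; the helper name parse_major_minor is made up for illustration and is not part of scikit-learn:

```python
import re

import sklearn


def parse_major_minor(version):
    """Return (major, minor) from a version string such as '0.17.1' or '0.18rc2'.

    Hypothetical helper for illustration; not a scikit-learn function.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    return int(match.group(1)), int(match.group(2))


installed = parse_major_minor(sklearn.__version__)
# n_splits and cv.split(X, y) belong to the API introduced in 0.18
uses_new_api = installed >= (0, 18)
print("scikit-learn", sklearn.__version__, "new API available:", uses_new_api)
```

If this reports that the new API is not available, the n_folds form applies; otherwise the n_splits form does.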

If your scikit-learn version is < 0.18, change the following lines (taken from the StratifiedKFold documentation for version 0.17):

# Note the extra parameter y, and that n_splits is called n_folds here
cv = StratifiedKFold(y, n_folds=6)

# Also note that cv itself is iterated over directly in the for loop
for train_index, test_index in cv:
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

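If your scikit-learn version is >= 0.18, and only then, you can use the n_splits parameter for cv, provided StratifiedKFold is imported from sklearn.model_selection rather than sklearn.cross_validation (taken from the current StratifiedKFold documentation, which I think is the one you were following):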

# The new API lives in sklearn.model_selection (added in 0.18)
from sklearn.model_selection import StratifiedKFold

# Note that the extra parameter y is removed here
cv = StratifiedKFold(n_splits=6)

# Also note that cv.split() is called here (as opposed to iterating over cv in version 0.17 above)
for train_index, test_index in cv.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
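
If the same script has to run on both old and new installations, the two variants can be hidden behind one function. This is only a sketch: stratified_splits is a hypothetical name, and the toy X and y below are made-up data, not the iris arrays from the question.

```python
import numpy as np


def stratified_splits(X, y, n_folds):
    """Return a list of (train_index, test_index) pairs.

    Hypothetical wrapper: tries the >= 0.18 API first and falls back
    to the older sklearn.cross_validation API if it is not importable.
    """
    try:
        from sklearn.model_selection import StratifiedKFold
        return list(StratifiedKFold(n_splits=n_folds).split(X, y))
    except ImportError:
        # scikit-learn < 0.18: y goes to the constructor, folds are n_folds
        from sklearn.cross_validation import StratifiedKFold
        return list(StratifiedKFold(y, n_folds=n_folds))


# Toy data for illustration: 12 samples, two balanced classes
X = np.arange(24).reshape(12, 2)
y = np.array([0, 1] * 6)

splits = stratified_splits(X, y, 3)
for train_index, test_index in splits:
    print("TRAIN:", train_index, "TEST:", test_index)
```

Either way the caller gets plain index arrays, so the loop body is the same for both versions.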

Suggestion

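Update your scikit-learn to the latest version (0.18 at the time of writing). Most of the documentation you will find by searching directly is for that version, which is what confused you here.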
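
After updating, the snippet from the question could be adapted roughly as follows. This is a sketch under assumptions, not the code from the linked example: random_state=0 is chosen here because the question does not show how random_state was defined, np.interp stands in for scipy.interp, and plotting is left out.

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import StratifiedKFold  # needs scikit-learn >= 0.18

iris = datasets.load_iris()
X, y = iris.data, iris.target
mask = y != 2                 # keep only classes 0 and 1 (binary problem)
X, y = X[mask], y[mask]

cv = StratifiedKFold(n_splits=6)
# random_state=0 is an assumption made for this sketch
classifier = svm.SVC(kernel="linear", probability=True, random_state=0)

mean_fpr = np.linspace(0, 1, 100)
tprs, aucs = [], []
for train_index, test_index in cv.split(X, y):
    classifier.fit(X[train_index], y[train_index])
    probas = classifier.predict_proba(X[test_index])
    # ROC curve for this fold, using the probability of class 1
    fpr, tpr, _ = roc_curve(y[test_index], probas[:, 1])
    # Resample the fold's curve onto the common FPR grid
    tprs.append(np.interp(mean_fpr, fpr, tpr))
    aucs.append(auc(fpr, tpr))

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[0], mean_tpr[-1] = 0.0, 1.0  # pin the curve's end points
mean_auc = auc(mean_fpr, mean_tpr)
print("per-fold AUC:", aucs)
print("AUC of the mean ROC curve:", mean_auc)
```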

Edit:

I have already answered a similar question of yours here: Cross-validation problem

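So next time, please state in the question itself which version of the library you are using, and remember to consult the documentation for that version rather than for a different one.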