scikit-learn: custom classifier compatible with GridSearchCV

Rod*_*una 6 python machine-learning scikit-learn

I have implemented my own classifier and now I want to run a grid search over it, but I get the following error: estimator.fit(X_train, y_train, **fit_params) TypeError: fit() takes 2 positional arguments but 3 were given

I followed this tutorial and used this template provided by the official scikit-learn documentation. My class is defined as follows:

from sklearn.base import BaseEstimator, ClassifierMixin

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y):
        # Some code
        return self

    def predict(self, X):
        # Some code that computes y_pred
        return y_pred

    def get_params(self, deep=True):
        return {'lr': self.lr}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self
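As a side note on the class above: BaseEstimator already implements get_params and set_params by reading the parameter names from the signature of __init__, so as long as __init__ only stores each argument under an attribute of the same name, the two overrides are not needed. A minimal sketch, assuming a recent scikit-learn; SimpleClassifier and its majority-label logic are hypothetical placeholders for the "# Some code" parts, and the mixin is listed before BaseEstimator, which is the order the current scikit-learn docs use:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone


class SimpleClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical placeholder classifier with no get_params/set_params overrides."""

    def __init__(self, lr=0.1):
        # Only store the argument, unchanged, under the same name as the parameter.
        self.lr = lr

    def fit(self, X, y):
        # Placeholder training logic: remember the most frequent label.
        labels, counts = np.unique(y, return_counts=True)
        self.classes_ = labels
        self.majority_ = labels[np.argmax(counts)]
        return self

    def predict(self, X):
        # Placeholder prediction: always return the majority label.
        return np.full(len(X), self.majority_)


clf = SimpleClassifier(lr=0.5)
print(clf.get_params())         # {'lr': 0.5}, derived from the __init__ signature
clf.set_params(lr=0.7)          # inherited set_params updates the attribute
print(clf.lr)                   # 0.7
print(clone(clf).get_params())  # {'lr': 0.7}, which is how GridSearchCV copies it
```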

I'm trying to run it through a grid search as follows:

from sklearn.model_selection import GridSearchCV

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

EDIT I

This is how I call it: gs.fit(['hello world', 'trying','hello world', 'trying', 'hello world', 'trying', 'hello world', 'trying'] , ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])

END EDIT I

The error is raised by the _fit_and_score method in python3.5/site-packages/sklearn/model_selection/_validation.py

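estimator.fit(X_train, y_train, **fit_params) is called with 3 arguments, but my estimator's fit takes only two, so the error makes sense to me, but I don't know how to fix it... I also tried adding some dummy parameters to the fit method, but it didn't work.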
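To see where that call comes from, here is a greatly simplified sketch of what a grid search does for every parameter combination and fold: clone the estimator, apply the candidate parameters, and call fit(X_train, y_train, **fit_params). Counting the bound self, that is three positional arguments, which is why the message says "3 were given". tiny_grid_search and WrongSignature are hypothetical names; clone, ParameterGrid, KFold and DummyClassifier are real scikit-learn APIs. The real GridSearchCV also handles scoring options, refitting, parallel jobs and error_score, and for a classifier with cv=4 it uses StratifiedKFold rather than plain KFold.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import KFold, ParameterGrid


def tiny_grid_search(estimator, param_grid, X, y, n_splits=4, **fit_params):
    """Greatly simplified stand-in for GridSearchCV.fit (not scikit-learn's code)."""
    X, y = np.asarray(X), np.asarray(y)
    results = []
    for params in ParameterGrid(param_grid):
        fold_scores = []
        for train_idx, test_idx in KFold(n_splits=n_splits).split(X, y):
            est = clone(estimator).set_params(**params)  # fresh copy per fold
            # The call from the traceback: self is bound automatically, X_train and
            # y_train are passed positionally, and any extras go in as keywords.
            est.fit(X[train_idx], y[train_idx], **fit_params)
            fold_scores.append(est.score(X[test_idx], y[test_idx]))
        results.append((params, float(np.mean(fold_scores))))
    return results


class WrongSignature(ClassifierMixin, BaseEstimator):
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X):  # deliberately wrong: there is no y parameter
        return self


X = np.arange(8).reshape(-1, 1)
y = np.array([0, 1] * 4)

# A correctly written estimator goes through without errors.
ok = tiny_grid_search(DummyClassifier(), {"strategy": ["most_frequent", "prior"]}, X, y)
print(ok)  # one (params, mean accuracy) pair per candidate

# A fit without y reproduces the error from the question.
error_message = ""
try:
    tiny_grid_search(WrongSignature(), {"lr": [0.1, 0.5]}, X, y)
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # contains "takes 2 positional arguments but 3 were given"
```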

EDIT II

Full error output:

Traceback (most recent call last):
  File "/home/rodrigo/no_version/text_classifier/MyClassifier.py", line 355, in <module>
    ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 639, in fit
    cv.split(X, y, groups)))
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__
    while self.dispatch_one_batch(iterator):
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 625, in dispatch_one_batch
    self._dispatch(tasks)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 588, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 111, in apply_async
    result = ImmediateResult(func)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 332, in __init__
    self.results = batch()
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in <listcomp>
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_validation.py", line 458, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
TypeError: fit() takes 2 positional arguments but 3 were given

END EDIT II

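SOLVED. Thanks everyone, I made a silly mistake: there were two different functions with the same name (fit). I had implemented another one with different parameters for custom purposes; once I renamed my "custom fit", everything worked.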
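The reason this fails silently is that a class body runs top to bottom like ordinary code, so a second def fit simply rebinds the name and replaces the first one, with no warning from Python itself (linters can catch it; pylint, for example, reports it as function-redefined, E0102). A minimal sketch with hypothetical class names, showing the shadowing and the fix of renaming the custom method:

```python
class ShadowedClassifier:
    def fit(self, X, y):
        """The fit that GridSearchCV is supposed to call."""
        return self

    def fit(self, X):  # hypothetical "custom fit": same name, different parameters
        """Defined later in the class body, so it silently replaces the fit above."""
        return self


error_message = ""
try:
    ShadowedClassifier().fit(["hello world"], ["I"])
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # contains "takes 2 positional arguments but 3 were given"


class FixedClassifier:
    def fit(self, X, y):
        return self

    def fit_custom(self, X):  # renamed, so it no longer shadows fit
        return self


FixedClassifier().fit(["hello world"], ["I"])  # works
```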

Thanks, and sorry.

Grr*_*Grr 6

The following code works for me:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        self.lr = lr
        # Some code

    def fit(self, X, y):
        # Some code
        return self

    def predict(self, X):
        # Some code
        return X % 3

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)
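Once the search runs, the results can be read from best_params_ and cv_results_. A sketch under the same assumptions as the example above (a recent scikit-learn, the same placeholder predict rule); here fit also records classes_, following the scikit-learn convention of a trailing underscore for fitted attributes. Because the placeholder predict ignores lr, every candidate gets the same mean score, so the "best" value is just a tie-break:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV


class MyClassifier(ClassifierMixin, BaseEstimator):
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y):
        self.classes_ = np.unique(y)  # fitted attribute, trailing underscore by convention
        return self  # returning self allows chaining such as clf.fit(X, y).predict(X)

    def predict(self, X):
        return np.asarray(X) % 3  # same placeholder rule as the answer


x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs = GridSearchCV(MyClassifier(), param_grid={"lr": [0.1, 0.5, 0.7]}, cv=4).fit(x, y)

print(gs.best_params_)                    # the selected lr value
print(gs.cv_results_["mean_test_score"])  # one mean accuracy per lr value
```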

My best guess is that you are passing something to the gs.fit method besides x and y, or that your MyClassifier.fit method is missing the self parameter.

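fit_params should only be populated if you pass a kwarg to the gs.fit method; otherwise it is an empty dict ({}), and **fit_params will not raise an argument error. To test this, create an instance of your classifier and pass it **{}. For example: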

clf = MyClassifier()
clf.fit(x, y, **{})

This will not raise a positional argument error.
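The other case, where extra keyword arguments are passed to gs.fit, can be checked directly. A sketch, assuming a recent scikit-learn (roughly 1.4 or later) with metadata routing left disabled, which is the default: keyword arguments given to gs.fit are forwarded to the estimator's fit, and array-likes with one entry per sample are sliced to each training fold. The n_weights_seen_ attribute and both class names are hypothetical, used only to make the forwarding visible. error_score="raise" is set so that a failing fit surfaces its original exception; current versions default to error_score=np.nan, which turns a failed fit into a NaN score with a warning.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV


class WeightAwareClassifier(ClassifierMixin, BaseEstimator):
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y, sample_weight=None):
        # Hypothetical bookkeeping: remember how many weights this fit received.
        self.n_weights_seen_ = None if sample_weight is None else len(sample_weight)
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        return np.asarray(X) % 3


class StrictClassifier(WeightAwareClassifier):
    def fit(self, X, y):  # accepts no extra keyword arguments
        return super().fit(X, y)


x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
params = {"lr": [0.1, 0.5, 0.7]}

# Forwarded: the refitted best_estimator_ receives all 30 weights.
gs = GridSearchCV(WeightAwareClassifier(), params, cv=4)
gs.fit(x, y, sample_weight=np.ones(30))
print(gs.best_estimator_.n_weights_seen_)

# An unexpected keyword argument reaches fit and raises a TypeError.
strict = GridSearchCV(StrictClassifier(), params, cv=4, error_score="raise")
error_message = ""
try:
    strict.fit(x, y, some_arg=123)
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # mentions the unexpected keyword argument 'some_arg'
```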

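So unless something is being passed to gs.fit, for example gs.fit(x, y, some_arg=123), it looks to me like you are missing one of the positional arguments in your definition of MyClassifier.fit. The error message you included seems to support that hypothesis, since it says fit() takes 2 positional arguments but 3 were given. If you define fit as follows, it takes 3 positional arguments: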

def fit(self, X, y): ...
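The counting in the answer (self, X and y make three positional parameters) can be checked with inspect.signature from the standard library. Applied to the real class, it also shows which fit Python actually kept when a method name is defined twice. A minimal sketch; count_positional, GoodFit and BadFit are hypothetical names:

```python
import inspect


def count_positional(func):
    """Count parameters that can be passed positionally (self included)."""
    kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return sum(p.kind in kinds for p in inspect.signature(func).parameters.values())


class GoodFit:
    def fit(self, X, y):
        return self


class BadFit:
    def fit(self, X):  # y is missing, so fit(X_train, y_train) fails
        return self


print(count_positional(GoodFit.fit))  # 3
print(count_positional(BadFit.fit))   # 2
print(inspect.signature(BadFit.fit))  # (self, X)
```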