(Python - sklearn)如何通过gridsearchcv将参数传递给自定义的ModelTransformer类

nkh*_*uyu 18 machine-learning parameter-passing python-2.7 scikit-learn cross-validation

下面是我的管道,似乎我不能通过使用ModelTransformer类将参数传递给我的模型,我从链接中获取它(http://zacstewart.com/2014/08/05/pipelines-of- featureunions-of-pipelines.html)

错误信息对我有意义,但我不知道如何解决这个问题.知道如何解决这个问题吗?谢谢.

# define a pipeline
pipeline = Pipeline([
('vect', DictVectorizer(sparse=False)),
('scale', preprocessing.MinMaxScaler()),
('ess', FeatureUnion(n_jobs=-1, 
                     transformer_list=[
     ('rfc', ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1,  n_estimators=100))),
     ('svc', ModelTransformer(SVC(random_state=1))),],
                     transformer_weights=None)),
('es', EnsembleClassifier1()),
])

# define the parameters for the pipeline
parameters = {
'ess__rfc__n_estimators': (100, 200),
}

# ModelTransformer class. It takes it from the link
(http://zacstewart.com/2014/08/05/pipelines-of-featureunions-of-pipelines.html)
class ModelTransformer(TransformerMixin):
    def __init__(self, model):
        self.model = model
    def fit(self, *args, **kwargs):
        self.model.fit(*args, **kwargs)
        return self
    def transform(self, X, **transform_params):
        return DataFrame(self.model.predict(X))

grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1, refit=True)
Run Code Online (Sandbox Code Playgroud)

错误消息:ValueError:估计器ModelTransformer的参数n_estimators无效.

Art*_*lev 18

GridSearchCV有嵌套对象的特殊命名约定.在你的情况ess__rfc__n_estimators代表ess.rfc.n_estimators,并根据的定义pipeline,它指向的财产n_estimators

ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1,  n_estimators=100)))
Run Code Online (Sandbox Code Playgroud)

显然,ModelTransformer实例没有这样的属性.

修复很简单:为了访问ModelTransformer一个需要使用model字段的底层对象.因此,网格参数变为

parameters = {
  'ess__rfc__model__n_estimators': (100, 200),
}
Run Code Online (Sandbox Code Playgroud)

PS它不是你的代码唯一的问题.要在GridSearchCV中使用多个作业,您需要使用可复制的所有对象.这是通过实施方法实现get_paramsset_params,你可以从借阅BaseEstimator混入.

  • @B_Miner,你应该从[`BaseEstimator`](http://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html)继承你的`SelectColumns`类,它提供了前面提到的`set_params`和`get_params`.或者,你可以实现自己的,但大多数时候你不想这样做. (8认同)
  • 我在寻找BaseEstimatorMixin.我继承了BaseEstimator,它就像一个魅力,谢谢! (2认同)