How to tune the parameters of nested Pipelines with GridSearchCV in scikit-learn?

liz*_*isk 9 scikit-learn

Is it possible to tune the parameters of nested Pipelines in scikit-learn? For example:

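from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

svm = Pipeline([
    ('chi2', SelectKBest(chi2)),
    ('cls', LinearSVC(class_weight='balanced'))  # was 'auto' in older releases
])

classifier = Pipeline([
    ('vectorizer', TfidfVectorizer()),
    ('ova_svm', OneVsRestClassifier(svm))
])

parameters = ...  # ? which keys reach 'chi2' and 'cls' inside OneVsRestClassifier?

GridSearchCV(classifier, parameters)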

If this can't be done directly, what would be a workaround?

Fre*_*Foo 16

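scikit-learn has a double-underscore notation for this, shown below. It works recursively and extends into OneVsRestClassifier, but you have to address the wrapped base estimator explicitly as `__estimator`: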

parameters = {'ova_svm__estimator__cls__C': [1, 10, 100],
              'ova_svm__estimator__chi2__k': [200, 500, 1000]}
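A minimal end-to-end sketch of that kind of grid in use, assuming a recent scikit-learn where `GridSearchCV` lives in `sklearn.model_selection`. The twelve-document corpus and its labels are hypothetical, invented only so the search has something to fit; `k` is kept small (5, 10, `"all"`) because the toy vocabulary has only a couple dozen terms, and `class_weight="balanced"` stands in for the question's `'auto'`, which newer releases no longer accept.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

# Hypothetical toy corpus: three classes, four short documents each.
docs = [
    "cheap pills buy now", "buy cheap watches now",
    "cheap offer buy today", "limited cheap deal buy",
    "meeting agenda for monday", "monday meeting notes attached",
    "agenda and notes for the meeting", "reschedule monday meeting",
    "football match tonight", "great goal in the match",
    "tonight football score", "match score and goal",
]
labels = ["spam"] * 4 + ["work"] * 4 + ["sport"] * 4

svm = Pipeline([
    ("chi2", SelectKBest(chi2)),
    ("cls", LinearSVC(class_weight="balanced")),
])
classifier = Pipeline([
    ("vectorizer", TfidfVectorizer()),
    ("ova_svm", OneVsRestClassifier(svm)),
])

# Path: pipeline step -> OneVsRestClassifier's "estimator" -> inner step -> parameter.
# k is small because the toy vocabulary is small; "all" keeps every feature.
parameters = {
    "ova_svm__estimator__cls__C": [0.1, 1, 10],
    "ova_svm__estimator__chi2__k": [5, 10, "all"],
}

# cv=2 so each class has enough documents per stratified fold.
search = GridSearchCV(classifier, parameters, cv=2)
search.fit(docs, labels)
print(search.best_params_)
```

The same keys work unchanged with `RandomizedSearchCV`, since both searches pass them to the pipeline's `set_params`.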


Anonymous 11

For the estimator you created, you can list all of its parameters, keyed by their labels, like this:

import pprint as pp

pp.pprint(sorted(classifier.get_params().keys()))

['ova_svm', 'ova_svm__estimator', 'ova_svm__estimator__chi2', 'ova_svm__estimator__chi2__k', 'ova_svm__estimator__chi2__score_func', 'ova_svm__estimator__cls', 'ova_svm__estimator__cls__C', 'ova_svm__estimator__cls__class_weight', 'ova_svm__estimator__cls__dual', 'ova_svm__estimator__cls__fit_intercept', 'ova_svm__estimator__cls__intercept_scaling', 'ova_svm__estimator__cls__loss', 'ova_svm__estimator__cls__max_iter', 'ova_svm__estimator__cls__multi_class', 'ova_svm__estimator__cls__penalty', 'ova_svm__estimator__cls__random_state', 'ova_svm__estimator__cls__tol', 'ova_svm__estimator__cls__verbose', 'ova_svm__estimator__steps', 'ova_svm__n_jobs', 'steps', 'vectorizer', 'vectorizer__analyzer', 'vectorizer__binary', 'vectorizer__decode_error', 'vectorizer__dtype', 'vectorizer__encoding', 'vectorizer__input', 'vectorizer__lowercase', 'vectorizer__max_df', 'vectorizer__max_features', 'vectorizer__min_df', 'vectorizer__ngram_range', 'vectorizer__norm', 'vectorizer__preprocessor', 'vectorizer__smooth_idf', 'vectorizer__stop_words', 'vectorizer__strip_accents', 'vectorizer__sublinear_tf', 'vectorizer__token_pattern', 'vectorizer__tokenizer', 'vectorizer__use_idf', 'vectorizer__vocabulary']

You can then pick the parameters from this list that you want GridSearchCV to search over.
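A sketch, assuming a recent scikit-learn, of narrowing that list down and checking candidate keys before running a search. `set_params` accepts the same double-underscore names as a parameter grid, so a mistyped key such as `chi2_k` (single underscore) is rejected with a `ValueError` right away. The variable names `cls_keys` and `rejected` are only illustrative.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

classifier = Pipeline([
    ("vectorizer", TfidfVectorizer()),
    ("ova_svm", OneVsRestClassifier(Pipeline([
        ("chi2", SelectKBest(chi2)),
        ("cls", LinearSVC()),
    ]))),
])

# Keep only the tunable keys that belong to the inner LinearSVC step.
prefix = "ova_svm__estimator__cls__"
cls_keys = sorted(k for k in classifier.get_params() if k.startswith(prefix))
print(cls_keys)

# Valid nested keys are applied all the way down the nesting.
classifier.set_params(ova_svm__estimator__cls__C=10, ova_svm__estimator__chi2__k=500)
print(classifier.get_params()["ova_svm__estimator__cls__C"])  # 10

# An invalid key (single underscore before "k") raises ValueError.
rejected = False
try:
    classifier.set_params(ova_svm__estimator__chi2_k=500)
except ValueError as err:
    rejected = True
    print("invalid key:", err)
```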