Sklearn - 自动定义get_params()

dia*_*hos 4 python scikit-learn

我正在尝试定义一个符合Sklearn估算器的类,例如

class MyEstimator():
    def __init__(self,verbose=False):
        self.verbose = verbose

    def get_params(self, deep=False):
        return {
            'verbose': self.verbose,
        }

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

    # Also def fit() and other stuff ...
Run Code Online (Sandbox Code Playgroud)

set_params()可以定义而无需显式列出所有参数名称.有没有办法以get_params()类似的方式定义?

从Sklearn我需要的是GridsearchCV,从我尝试过的,它似乎get_params确定了在交叉验证期间可以注入哪些参数.

Max*_*axU 9

只是继承类BaseEstimator,它实现get_params()set_params()为您服务.

演示:

In [21]: from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, ClusterMixin

In [22]: from sklearn.base import BaseEstimator
    ...:
    ...: class MyEstimator(BaseEstimator):
    ...:     def __init__(self,verbose=False):
    ...:         self.verbose = verbose

In [23]: est = MyEstimator(verbose=True)

In [24]: est.get_params()
Out[24]: {'verbose': True}

In [25]: est.set_params(verbose=False)
Out[25]: MyEstimator(verbose=False)

In [26]: est.get_params()
Out[26]: {'verbose': False}
Run Code Online (Sandbox Code Playgroud)

PS你可能还需要也继承你估计从一个(ClassifierMixin,RegressorMixin,ClusterMixin),这取决于你要实现什么样的估计的...