sklearn 模型是什么类型?

FCh*_*Chm 14 python scikit-learn

我正在编写一些代码,根据某些数据评估不同的 sklearn 模型。我正在使用类型提示,既是为了我自己的教育,也是为了帮助最终必须阅读我的代码的其他人。

我的问题是如何指定 sklearn 预测器的类型(例如LinearRegression())?

例如:

def model_tester(model : Predictor,
                 parameter: int
                 ) -> np.ndarray:
     """An example function with type hints."""

     # do stuff to model 

     return values
Run Code Online (Sandbox Code Playgroud)

我看到打字库可以创建新类型,或者我可以TypeVar用来做:

Predictor = TypeVar('Predictor') 
Run Code Online (Sandbox Code Playgroud)

但如果 sklearn 模型已经有传统类型,我就不想使用它。

检查 LinearRegression() 的类型产生:

 sklearn.linear_model.base.LinearRegression
Run Code Online (Sandbox Code Playgroud)

这显然很有用,但前提是我对 LinearRegression 模型感兴趣。

Pet*_*ter 13

从 Python 3.8 开始(或更早版本使用Typing-extensions),您可以使用typing.Protocol. 使用协议,您可以使用称为结构子类型的概念来准确定义类型的预期结构:

from typing import Protocol
# from typing_extensions import Protocol  # for Python <3.8

class ScikitModel(Protocol):
    def fit(self, X, y, sample_weight=None): ...
    def predict(self, X): ...
    def score(self, X, y, sample_weight=None): ...
    def set_params(self, **params): ...
Run Code Online (Sandbox Code Playgroud)

然后您可以将其用作类型提示:

def do_stuff(model: ScikitModel) -> Any:
    model.fit(train_data, train_labels)  # this type checks 
    score = model.score(test_data, test_labels)  # this type checks
    ...
Run Code Online (Sandbox Code Playgroud)


Flo*_*nGD 9

我认为所有模型继承的最通用的类​​是sklearn.base.BaseEstimator.

如果您想更具体,可以使用sklearn.base.ClassifierMixinsklearn.base.RegressorMixin

所以我会这样做:

from sklearn.base import RegressorMixin


def model_tester(model: RegressorMixin, parameter: int) -> np.ndarray:
     """An example function with type hints."""

     # do stuff to model 

     return values
Run Code Online (Sandbox Code Playgroud)

我不是类型检查的专家,如果这不对,请纠正我。

  • 谢谢您的回答。我尝试了 BaseEstimator 和 ClassifierMixin。但是当我调用 self.estimator.fit 时,我的 IDE (Pycharm) 抱怨它找不到“fit”属性。这是对的。这些类没有实现适配。它是为每个估计器(例如 LogisticRegression)单独实现的。有谁知道应该采用 scikit-learn 估计器的参数的正确类型提示是什么? (3认同)

小智 7

您可以将类型设置为sklearn.pipeline.Pipeline随处可见。这可能是一种无需创建额外实体的解决方案。


小智 6

一个好的解决方法是创建您自己的自定义类型提示类(使用 Union),其中包含您常用的所有模型。它需要更多的努力,但允许您具体并使用 PyCharm。

ModelRegressor = Union[LinearRegression, DecisionTreeRegressor, RandomForestRegressor, SVR]

def foo(model: ModelRegressor):
    do_something
Run Code Online (Sandbox Code Playgroud)