如何使用 scikit-learn API 实现元估计器?

Ger*_*nuk 6 python scikit-learn

我想实现一个与所有 scikit-learn 兼容的简单包装器/元估计器。很难找到我到底需要什么的完整描述。

目标是让回归器也学习阈值以成为分类器。所以我想出了:

from sklearn.base import BaseEstimator, ClassifierMixin, clone

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        # threshold_ does not get initialized in __init__ ??

    def fit(self, X, y, optimal_threshold):
        self.regressor = clone(self.regressor)    # is this required my sklearn??
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict()
        self.threshold_ = optimal_threshold(y_raw)

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y
Run Code Online (Sandbox Code Playgroud)

这是否实现了我需要的完整 API?

我的主要问题是将threshold. 我希望它只学习一次,并且可以在后续.fit调用中重新使用新数据而无需重新调整。但是对于当前版本,每次.fit调用都必须重新调整- 我不想要?

另一方面,如果我self.threshold将其设为固定参数并将其传递给__init__,那么我不应该用数据更改它?

如何制作一个threshold可以在一次调用中调整.fit并在后续.fit调用中修复的参数?

Adi*_*hya 1

实际上,前几天我写了一篇关于此的博客文章。我假设您正在尝试构建类似的东西,TransformedTargetRegressor我建议您查看其源代码来构建类似的东西。

您当前的实施似乎是正确的。就这个问题而言:

如何制作一个可以在一次调用中调整.fit并在后续调用中修复的阈值参数.fit

我建议反对,因为scikit-learn的 API 是基于fit重新拟合模型的所有可调方面的方法。您可以在此处选择两条路线,要么添加一个**kwarg明确防止更新的配合theshold,要么您可以采用@rotem-tal 建议的方式。如果您选择后者,它可能看起来像这样:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin

def optimal_threshold(y_raw: np.ndarray) -> np.ndarray:
    return np.array([0.1, 0.5, 1])  # some implementation here

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        self.threshold = None

    def fit(self, X, y, optimal_threshold):
        # you don't need to clone the regressor
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict()
        if self.threshold is None:
            self.threshold = optimal_threshold(y_raw)

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y
Run Code Online (Sandbox Code Playgroud)