gga*_*rav 5 python machine-learning scikit-learn
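I am trying to create a custom estimator based on scikit-learn. I wrote the dummy code below to explain my problem. In the score method, I am trying to access mean_, which is calculated in fit, but I can't. What am I doing wrong? I have tried many things and read three or four articles, but I couldn't find the problem.
I have read the documentation and made some changes, but nothing worked. I also tried inheriting from BaseEstimator and ClassifierMixin, but that didn't work either.
This is a dummy program; don't go by what it is trying to do.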
import numpy as np
from sklearn.model_selection import cross_val_score
class FilterElems:
    def __init__(self, thres):
        self.thres = thres
    def fit(self, X, y=None, **kwargs):
        self.mean_ = np.mean(X)
        self.std_ = np.std(X)
        return self
    def predict(self, X):
        # return sign(self.predict(inputs))
        X = (X - self.mean_) / self.std_
        return X[X > self.thres]
    def get_params(self, deep=False):
        return {'thres': self.thres}
    def score(self, *x):
        print(self.mean_)  # errors out, mean_ and std_ are wiped out
        if len(x[1]) > 50:
            return 1.0
        else:
            return 0.5
model = FilterElems(thres=0.5)
print(cross_val_score(model,
                      np.random.randint(1, 1000, (100, 100)),
                      None,
                      scoring=model.score,
                      cv=5))
Error:
AttributeError: 'FilterElems' object has no attribute 'mean_'
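You're almost there.
The signature of a scorer is scorer(estimator, X, y). cross_val_score calls the scorer by passing the fitted estimator object as the first argument. Since your score method is a variadic function, the first item of x holds that estimator. self, on the other hand, is not the fitted model: because you pass the bound method model.score, self is always the original model object, and that object is never fitted, since cross_val_score fits a fresh clone of it in each fold.
Change your score method to: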
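To see why x[0] holds the estimator while self does not, here is a pure-Python sketch of the calling convention. Recorder is a hypothetical stand-in for FilterElems, and the call bound_scorer(fitted_clone, X_test) only mimics how cross_val_score invokes the scorer; no scikit-learn is involved.

```python
class Recorder:
    """Hypothetical stand-in for FilterElems; it only records how score() is called."""

    def __init__(self, name):
        self.name = name

    def score(self, *x):
        # `self` is fixed when the bound method is created (e.g. model.score);
        # everything the caller passes positionally lands in the tuple `x`.
        return self.name, x


original = Recorder("original")          # plays the role of `model`
fitted_clone = Recorder("fitted clone")  # plays the role of the per-fold clone

# Mimic the scorer call: scorer(estimator, X_test), with y omitted.
bound_scorer = original.score
owner, args = bound_scorer(fitted_clone, [[1, 2], [3, 4]])

print(owner)         # original
print(args[0].name)  # fitted clone
print(len(args[1]))  # 2
```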
    def score(self, *x):
        print(x[0].mean_)
        if len(x[1]) > 50:
            return 1.0
        else:
            return 0.5
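The reason x[0].mean_ works while self.mean_ fails is that cross_val_score never fits the object you pass in: for each fold it calls sklearn.base.clone on it and fits the clone. The sketch below reproduces only that single clone-and-fit step, not the full cross-validation loop. It uses a reduced FilterElems that inherits from BaseEstimator, an assumption made so that clone and get_params work without hand-written code.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class FilterElems(BaseEstimator):
    """Reduced version of the question's estimator: only fit() is needed here."""

    def __init__(self, thres=0.5):
        self.thres = thres

    def fit(self, X, y=None):
        # Learned attributes end with "_" and exist only after fit().
        self.mean_ = np.mean(X)
        self.std_ = np.std(X)
        return self


model = FilterElems(thres=0.5)
X_train = np.arange(10.0).reshape(5, 2)

# Roughly what cross_val_score does for each fold: clone the unfitted
# estimator (same constructor params, no learned state), then fit the clone.
fitted = clone(model).fit(X_train)

print(hasattr(fitted, "mean_"))     # True
print(hasattr(model, "mean_"))      # False: the original is never fitted
print(fitted.thres == model.thres)  # True: parameters are copied
```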
Working code
import numpy as np
from sklearn.model_selection import cross_val_score
class FilterElems:
    def __init__(self, thres):
        self.thres = thres
    def fit(self, X, y=None, **kwargs):
        self.mean_ = np.mean(X)
        self.std_ = np.std(X)
        return self
    def predict(self, X):
        X = (X - self.mean_) / self.std_
        return X[X > self.thres]
    def get_params(self, deep=False):
        return {'thres': self.thres}
    def score(self, estimator, *x):
        # estimator is the fitted clone for this fold, not self
        print(estimator.mean_, estimator.std_)
        if len(x[0]) > 50:
            return 1.0
        else:
            return 0.5
model = FilterElems(thres=0.5)
print(cross_val_score(model,
                      np.random.randint(1, 1000, (100, 100)),
                      None,
                      scoring=model.score,
                      cv=5))
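A more conventional pattern than passing the bound method model.score is to write the scorer as a plain function with the (estimator, X, y) signature, so there is no self to confuse with the fitted estimator. The sketch below is one way to do that, not the only one: filter_scorer is a hypothetical name, FilterElems is assumed to inherit from BaseEstimator (which supplies get_params), and y defaults to None so the function also works when no targets are passed. The data come from numpy's default_rng with a fixed seed, so the printed means will differ slightly from the output shown for the working code.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_score


class FilterElems(BaseEstimator):
    def __init__(self, thres=0.5):
        self.thres = thres

    def fit(self, X, y=None, **kwargs):
        self.mean_ = np.mean(X)
        self.std_ = np.std(X)
        return self

    def predict(self, X):
        # Standardize with the statistics learned in fit(), keep values above thres.
        X = (X - self.mean_) / self.std_
        return X[X > self.thres]


def filter_scorer(estimator, X, y=None):
    """Hypothetical module-level scorer; `estimator` is the fitted clone for this fold."""
    print(estimator.mean_, estimator.std_)
    # Same dummy rule as the answer: more than 50 test rows scores 1.0.
    return 1.0 if len(X) > 50 else 0.5


rng = np.random.default_rng(0)
X = rng.integers(1, 1000, size=(100, 100))
scores = cross_val_score(FilterElems(thres=0.5), X, None,
                         scoring=filter_scorer, cv=5)
print(scores)  # [0.5 0.5 0.5 0.5 0.5]
```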
Output
504.750125 288.84916035447355
501.7295 289.47825925231416
503.743375 288.8964170227962
503.0325 287.8292687406025
500.041 289.3488678377712
[0.5 0.5 0.5 0.5 0.5]
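Every fold scores 0.5 because of the dummy rule, not because anything failed. With an integer cv=5 and y=None, cross_val_score splits with an unshuffled KFold(n_splits=5), so each test fold holds 100 / 5 = 20 rows and len(x[0]) > 50 is never true. A quick check of the fold sizes:

```python
import numpy as np
from sklearn.model_selection import KFold

# Same shape as the data in the question: 100 samples, 100 features.
X = np.zeros((100, 100))

# Integer cv with y=None corresponds to an unshuffled KFold.
test_sizes = [len(test_idx) for _, test_idx in KFold(n_splits=5).split(X)]
print(test_sizes)  # [20, 20, 20, 20, 20]

# 20 > 50 is False in every fold, so the dummy scorer always returns 0.5.
print([1.0 if n > 50 else 0.5 for n in test_sizes])  # [0.5, 0.5, 0.5, 0.5, 0.5]
```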