实现自定义scikit-learn估算器的完整规范是什么?

rho*_*ron 7 python scikit-learn

我正在使用我自己的预测器,并希望像使用任何scikit例程(例如RandomForestRegressor)一样使用它.我有一个类fit和包含predict似乎工作正常的方法.但是,当我尝试使用某些scikit方法时,例如交叉验证,我会收到如下错误:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in cross_val_
score
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\externals\joblib\parallel.py", line 516, in __
call__
    for function, args, kwargs in iterable:
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in <genexpr>
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 43, in clone
    % (repr(estimator), type(estimator)))
TypeError: Cannot clone object '<__main__.Custom instance at 0x033A6990>' (type <type 'inst
ance'>): it does not seem to be a scikit-learn estimator a it does not implement a 'get_para
ms' methods.
Run Code Online (Sandbox Code Playgroud)

我看到它希望我实现一些方法(get_params可能set_paramsscore我一样)但我不确定制作这些方法的正确规范是什么.是否有关于此主题的信息?谢谢.

Fre*_*Foo 11

scikit-learn文档中提供了完整的说明,并且您的真实等人本文中阐述了API背后的原理.总之,除了fit,你需要什么的估计是get_paramsset_params那个回报(作为dict),并设置(从kwargs)估计的超参数,即学习算法本身的参数(而不是在数据参数,就会学到).这些参数应与__init__参数匹配.

这两种方法都可以通过继承类中的方法获得sklearn.base,但如果您不希望代码依赖于scikit-learn,则可以自己提供.

请注意,输入验证应该fit在构造函数中完成,而不是构造函数,因为否则您仍然可以设置无效参数set_paramsfit以意外方式失败.