在scikit-learn中训练神经网络的早期停止

use*_*384 4 python machine-learning neural-network scikit-learn cross-validation

这个问题非常具体到Python库scikit-learn.如果最好将其发布到其他地方,请告诉我.谢谢!

现在的问题......

我有一个基于BaseEstimator的前馈神经网络类ffnn,我用SGD训练.它工作正常,我也可以使用GridSearchCV()并行训练它.

现在我想在函数ffnn.fit()中实现提前停止,但为此我还需要访问fold的验证数据.一种方法是更改​​sklearn.grid_search.fit_grid_point()中的行

clf.fit(X_train, y_train, **fit_params)
Run Code Online (Sandbox Code Playgroud)

变成类似的东西

clf.fit(X_train, y_train, X_test, y_test, **fit_params)
Run Code Online (Sandbox Code Playgroud)

并且还改变ffnn.fit()以获取这些参数.这也会影响sklearn中的其他分类器,这是一个问题.我可以通过检查fit_grid_point()中的某种标志来避免这种情况,该标志告诉我何时以上述两种方式之一调用clf.fit().

有人可以建议一种不同的方式来执行此操作,我不必编辑sklearn库中的任何代码吗?

或者,将X_train和y_train进一步拆分为火车/验证集并检查一个好的停止点,然后在所有X_train上重新训练模型是否正确?

谢谢!

ogr*_*sel 7

您可以让神经网络模型从内部提取验证集X_trainy_train通过使用train_test_split函数进行实例化.

编辑:

或者,将X_train和y_train进一步拆分为火车/验证集并检查一个好的停止点,然后在所有X_train上重新训练模型是否正确?

是的,但那会很贵.您可以找到停止点,然后只需对用于查找停止点的验证数据执行一次额外的传递.