如何保存 GridSearchCV 对象?

E.T*_*dis 9 python save scikit-learn keras grid-search

最近,我一直致力于应用网格搜索交叉验证(sklearn GridSearchCV)在带有 Tensorflow 后端的 Keras 中进行超参数调整。调整我的模型后,我试图保存 GridSearchCV 对象以备后用,但没有成功。

超参数调整如下:

x_train, x_val, y_train, y_val = train_test_split(NN_input, NN_target, train_size = 0.85, random_state = 4)

history = History() 
kfold = 10


regressor = KerasRegressor(build_fn = create_keras_model, epochs = 100, batch_size=1000, verbose=1)

neurons = np.arange(10,101,10) 
hidden_layers = [1,2]
optimizer = ['adam','sgd']
activation = ['relu'] 
dropout = [0.1] 

parameters = dict(neurons = neurons,
                  hidden_layers = hidden_layers,
                  optimizer = optimizer,
                  activation = activation,
                  dropout = dropout)

gs = GridSearchCV(estimator = regressor,
                  param_grid = parameters,
                  scoring='mean_squared_error',
                  n_jobs = 1,
                  cv = kfold,
                  verbose = 3,
                  return_train_score=True))

grid_result = gs.fit(NN_input,
                    NN_target,
                    callbacks=[history],
                    verbose=1,
                    validation_data=(x_val, y_val))
Run Code Online (Sandbox Code Playgroud)

备注:create_keras_model 函数初始化和编译 Keras Sequential 模型。

执行交叉验证后,我尝试使用以下代码保存网格搜索对象 (gs):

from sklearn.externals import joblib

joblib.dump(gs, 'GS_obj.pkl')
Run Code Online (Sandbox Code Playgroud)

我得到的错误如下:

TypeError: can't pickle _thread.RLock objects
Run Code Online (Sandbox Code Playgroud)

你能告诉我这个错误的原因是什么吗?

谢谢!

PS:joblib.dump 方法可以很好地保存用于从 sklearn 训练 MLPRegressors 的 GridSearchCV 对象。

mak*_*kis 9

尝试这个:

from sklearn.externals import joblib
joblib.dump(gs.best_estimator_, 'filename.pkl')
Run Code Online (Sandbox Code Playgroud)

如果您想将对象转储到一个文件中 - 使用:

joblib.dump(gs.best_estimator_, 'filename.pkl', compress = 1)
Run Code Online (Sandbox Code Playgroud)

简单示例:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from sklearn.externals import joblib

iris = datasets.load_iris()
parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
svc = svm.SVC()
gs = GridSearchCV(svc, parameters)
gs.fit(iris.data, iris.target)

joblib.dump(gs.best_estimator_, 'filename.pkl')

#['filename.pkl']
Run Code Online (Sandbox Code Playgroud)

编辑1:

您还可以保存整个对象:

joblib.dump(gs, 'gs_object.pkl')
Run Code Online (Sandbox Code Playgroud)


lie*_*dji 9

import joblib 直接地

代替

from sklearn.externals import joblib

使用以下方法保存对象或结果:

joblib.dump(gs, 'model_file_name.pkl')

并使用以下方法加载您的结果:

joblib.load("model_file_name.pkl")

这是一个简单的工作示例:


import joblib

#save your model or results
joblib.dump(gs, 'model_file_name.pkl')

#load your model for further usage
joblib.load("model_file_name.pkl")

Run Code Online (Sandbox Code Playgroud)