如何腌制Keras模型?

Sid*_*hou 11 python machine-learning pickle keras

官方文件声明"不建议使用pickle或cPickle来保存Keras模型."

然而,我对酸洗Keras模型的需求源于使用sklearn的RandomizedSearchCV(或任何其他超参数优化器)的超参数优化.将结果保存到文件中至关重要,因为脚本可以在分离的会话中远程执行等.

基本上,我想:

trial_search = RandomizedSearchCV( estimator=keras_model, ... )
pickle.dump( trial_search, open( "trial_search.pickle", "wb" ) )
Run Code Online (Sandbox Code Playgroud)

far*_*n4u 7

到目前为止,Keras型号是可腌制的。但是我们仍然建议使用model.save()将模型保存到磁盘。

  • model.save() 在 python3 中存在问题。每次我加载模型时,它都会针对相同的输入预测不同的结果。 (8认同)
  • 使用 model.save() 保存的模型将与 Keras 的未来版本兼容,也可以导出到其他平台和实现(deeplearning4j、Apple CoreML 等)。 (2认同)
  • 我尝试腌制,但它提出了“不能腌制弱引用”。腌制之前需要做什么吗? (2认同)

Sid*_*hou 5

这就像一个魅力http://zachmoshe.com/2017/04/03/pickling-keras-models.html

import types
import tempfile
import keras.models

def make_keras_picklable():
    def __getstate__(self):
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            keras.models.save_model(self, fd.name, overwrite=True)
            model_str = fd.read()
        d = { 'model_str': model_str }
        return d

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            model = keras.models.load_model(fd.name)
        self.__dict__ = model.__dict__


    cls = keras.models.Model
    cls.__getstate__ = __getstate__
    cls.__setstate__ = __setstate__

make_keras_picklable()
Run Code Online (Sandbox Code Playgroud)

PS。由于循环引用model.to_json()引发了我的问题,TypeError('Not JSON Serializable:', obj)上面的代码以某种方式吞噬了此错误,从而导致pickle函数永远运行。


Anu*_*pta 5

分别使用get_weights和set_weights保存和加载模型。

看一下此链接: 无法将DataFrame保存到HDF5(“对象头消息太大”)

#for heavy model architectures, .h5 file is unsupported.
weigh= model.get_weights();    pklfile= "D:/modelweights.pkl"
try:
    fpkl= open(pklfile, 'wb')    #Python 3     
    pickle.dump(weigh, fpkl, protocol= pickle.HIGHEST_PROTOCOL)
    fpkl.close()
except:
    fpkl= open(pklfile, 'w')    #Python 2      
    pickle.dump(weigh, fpkl, protocol= pickle.HIGHEST_PROTOCOL)
    fpkl.close()
Run Code Online (Sandbox Code Playgroud)