将keras模型权重直接保存到字节/内存?

Ada*_*hes 6 python keras tensorflow

Keras 允许保存整个模型或仅保存模型权重(请参阅线程)。保存权重时,必须将它们保存到文件中,例如:

model = keras_model()
model.save_weights('/tmp/model.h5')
Run Code Online (Sandbox Code Playgroud)

我只想将字节保存到内存中,而不是写入文件。就像是

model.dump_weights()
Run Code Online (Sandbox Code Playgroud)

Tensorflow 似乎没有这个,所以作为一种解决方法,我写入磁盘然后读入内存:

temp = '/tmp/weights.h5'
model.save_weights(temp)
with open(temp, 'rb') as f:
    weightbytes = f.read()
Run Code Online (Sandbox Code Playgroud)

有什么办法可以避免这种迂回吗?

Ada*_*hes 0

感谢@ddoGas 指出了该model.get_weights()方法,该方法返回一个可以序列化的权重列表。只是一些背景信息,说明为什么我不以传统方式保存模型:我们正在使用将模型和自定义行为关联起来的模型包装器类。例如,在预测发生之前需要进行特殊验证:

class CNN:
   ...
   def predict():
       self.do_special_validation()
       self.model.predict()
Run Code Online (Sandbox Code Playgroud)

因此,我们序列化CNN类而不仅仅是底层模型。这是腌制整个对象的解决方案。(pickle(CNN())失败,否则我们就使用它)

import pickle

def serialize(cnn):
    return pickle.dumps({
        "weights": cnn.model.get_weights(),
        "cnnclass": cnn.__class__
    })

def deserialize(cnn_bytes):
    loaded = pickle.loads(cnn_bytes)
    weights, cnnclass = loaded['weights'], loaded['cnnclass']
    cnninstance = cnnclass()
    cnninstance.model.set_weights(weights)
    return cnninstance
Run Code Online (Sandbox Code Playgroud)

效果很好,谢谢!

PS 注意使用,cnn.__class__因为不一定希望将其CNN直接绑定到类,但它通常适用于任何具有cnn.model属性的类。