Ada*_*hes 6 python keras tensorflow
Keras 允许保存整个模型或仅保存模型权重(请参阅线程)。保存权重时,必须将它们保存到文件中,例如:
model = keras_model()
model.save_weights('/tmp/model.h5')
Run Code Online (Sandbox Code Playgroud)
我只想将字节保存到内存中,而不是写入文件。就像是
model.dump_weights()
Run Code Online (Sandbox Code Playgroud)
Tensorflow 似乎没有这个,所以作为一种解决方法,我写入磁盘然后读入内存:
temp = '/tmp/weights.h5'
model.save_weights(temp)
with open(temp, 'rb') as f:
weightbytes = f.read()
Run Code Online (Sandbox Code Playgroud)
有什么办法可以避免这种迂回吗?
感谢@ddoGas 指出了该model.get_weights()
方法,该方法返回一个可以序列化的权重列表。只是一些背景信息,说明为什么我不以传统方式保存模型:我们正在使用将模型和自定义行为关联起来的模型包装器类。例如,在预测发生之前需要进行特殊验证:
class CNN:
...
def predict():
self.do_special_validation()
self.model.predict()
Run Code Online (Sandbox Code Playgroud)
因此,我们序列化CNN
类而不仅仅是底层模型。这是腌制整个对象的解决方案。(pickle(CNN())
失败,否则我们就使用它)
import pickle
def serialize(cnn):
return pickle.dumps({
"weights": cnn.model.get_weights(),
"cnnclass": cnn.__class__
})
def deserialize(cnn_bytes):
loaded = pickle.loads(cnn_bytes)
weights, cnnclass = loaded['weights'], loaded['cnnclass']
cnninstance = cnnclass()
cnninstance.model.set_weights(weights)
return cnninstance
Run Code Online (Sandbox Code Playgroud)
效果很好,谢谢!
PS 注意使用,cnn.__class__
因为不一定希望将其CNN
直接绑定到类,但它通常适用于任何具有cnn.model
属性的类。
归档时间: |
|
查看次数: |
5019 次 |
最近记录: |