Bel*_*iri 3 python callback keras tensorflow
我正在训练NN,并希望在预测阶段每N个时期保存模型权重.我提出这个草案代码,它的灵感来自于@grovina 在这里的回应.请你提出建议吗?提前致谢.
from keras.callbacks import Callback
class WeightsSaver(Callback):
def __init__(self, model, N):
self.model = model
self.N = N
self.epoch = 0
def on_batch_end(self, epoch, logs={}):
if self.epoch % self.N == 0:
name = 'weights%08d.h5' % self.epoch
self.model.save_weights(name)
self.epoch += 1
Run Code Online (Sandbox Code Playgroud)
然后将其添加到fit调用:每5个时期保存一次权重:
model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])
Run Code Online (Sandbox Code Playgroud)
umu*_*tto 12
您不应该为回调传递模型.它已经通过它的超级访问模型.所以删除__init__(..., model, ...)参数和self.model = model.self.model无论如何,您应该能够访问当前模型.你也在每个批次结束时保存它,这可能不是你想要的,你可能想要它on_epoch_end.
但无论如何,你正在做的事情可以通过天真的modelcheckpoint回调来完成.您不需要编写自定义的.你可以使用如下;
mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
6812 次 |
| 最近记录: |