Tensorflow Callback:如何将最佳模型保存在内存而不是磁盘上

New*_*per 5 python callback tensorflow

我使用 Tensorflow 使用以下函数进行回归

import tensorflow as tf

def ff(*args, **kwargs):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.Input(shape=[inp_train.shape[-1],]))
    for i in range(n_layer):
        model.add(tf.keras.layers.Dense(n_unit, activation=act))
    model.add(tf.keras.layers.Dense(out_train.shape[1]))
    model.compile(optimizer=opt, loss='mae')
    early_stop  = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100)
    check_point = tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True)
    model.fit(inp_train, out_train, epochs=n_epoch, batch_size=s_batch, validation_data=(inp_val, out_val), callbacks=[early_stop, check_point], verbose=0)
    best_model = tf.keras.models.load_model('best_model.h5')
    return model, best_mode
Run Code Online (Sandbox Code Playgroud)

如您所见,我通过回调保存最佳模型check_point,并在以后使用它进行预测。问题是,这样我必须先将最佳模型保存在磁盘上,然后再从磁盘加载它。如果我想并行进行几次运行,因为每次运行都会创建一个具有相同名称的文件,所以它不起作用。

那么,如何在变量中分配最佳模型而不必将其保存在磁盘上呢?

use*_*087 6

注意:我修复了一个错误并且未经测试

我必须自己做这件事并认为我会分享:

打回来:

class SaveBestModel(tf.keras.callbacks.Callback):
    def __init__(self, save_best_metric='val_loss', this_max=False):
        self.save_best_metric = save_best_metric
        self.max = this_max
        if this_max:
            self.best = float('-inf')
        else:
            self.best = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        metric_value = logs[self.save_best_metric]
        if self.max:
            if metric_value > self.best:
                self.best = metric_value
                self.best_weights = self.model.get_weights()

        else:
            if metric_value < self.best:
                self.best = metric_value
                self.best_weights= self.model.get_weights()
Run Code Online (Sandbox Code Playgroud)

用法:

save_best_model = SaveBestModel()
model.fit(data, callbacks=[save_best_model]
#set best weigts
model.set_weights(save_best_model.best_weights)
Run Code Online (Sandbox Code Playgroud)