相关疑难解决方法(0)

创建keras回调以在培训期间保存每个批次的模型预测和目标

我正在Keras(tensorflow后端)构建一个简单的Sequential模型.在培训期间,我想检查各个培训批次和模型预测.因此,我正在尝试创建一个自定义Callback,以保存每个培训批次的模型预测和目标.但是,该模型不使用当前批次进行预测,而是使用整个训练数据.

我怎样才能将当前的培训批次交给Callback

我如何访问Callbackself.predhis和self.targets中保存的批次和目标?

我当前的版本如下:

callback_list = [prediction_history((self.x_train, self.y_train))]

self.model.fit(self.x_train, self.y_train, batch_size=self.batch_size, epochs=self.n_epochs, validation_data=(self.x_val, self.y_val), callbacks=callback_list)

class prediction_history(keras.callbacks.Callback):
    def __init__(self, train_data):
        self.train_data = train_data
        self.predhis = []
        self.targets = []

    def on_batch_end(self, epoch, logs={}):
        x_train, y_train = self.train_data
        self.targets.append(y_train)
        prediction = self.model.predict(x_train)
        self.predhis.append(prediction)
        tf.logging.info("Prediction shape: {}".format(prediction.shape))
        tf.logging.info("Targets shape: {}".format(y_train.shape))
Run Code Online (Sandbox Code Playgroud)

callback keras tensorflow

17
推荐指数
2
解决办法
9139
查看次数

标签 统计

callback ×1

keras ×1

tensorflow ×1