我正在Keras(tensorflow后端)构建一个简单的Sequential模型.在培训期间,我想检查各个培训批次和模型预测.因此,我正在尝试创建一个自定义Callback,以保存每个培训批次的模型预测和目标.但是,该模型不使用当前批次进行预测,而是使用整个训练数据.
我怎样才能将当前的培训批次交给Callback?
我如何访问Callbackself.predhis和self.targets中保存的批次和目标?
我当前的版本如下:
callback_list = [prediction_history((self.x_train, self.y_train))]
self.model.fit(self.x_train, self.y_train, batch_size=self.batch_size, epochs=self.n_epochs, validation_data=(self.x_val, self.y_val), callbacks=callback_list)
class prediction_history(keras.callbacks.Callback):
def __init__(self, train_data):
self.train_data = train_data
self.predhis = []
self.targets = []
def on_batch_end(self, epoch, logs={}):
x_train, y_train = self.train_data
self.targets.append(y_train)
prediction = self.model.predict(x_train)
self.predhis.append(prediction)
tf.logging.info("Prediction shape: {}".format(prediction.shape))
tf.logging.info("Targets shape: {}".format(y_train.shape))
Run Code Online (Sandbox Code Playgroud)