访问 tf.keras.callbacks.Callback 中已弃用的属性“validation_data”

Con*_*tih 6 python callback keras tensorflow

我决定从 keras 切换到 tf.keras(这里推荐)。因此,我安装tf.__version__=2.0.0tf.keras.__version__=2.2.4-tf。在我的代码的较旧版本(使用一些较旧的 Tensorflow 版本tf.__version__=1.x.x)中,我使用回调来计算每个时期结束时整个验证数据的自定义指标。这样做的想法来自这里。但是,似乎不推荐使用“validation_data”属性,因此以下代码不再起作用。

class ValMetrics(Callback):

    def on_train_begin(self, logs={}):

        self.val_all_mse = []

    def on_epoch_end(self, epoch, logs):

        val_predict = np.asarray(self.model.predict(self.validation_data[0]))
        val_targ = self.validation_data[1]

        val_epoch_mse = mse_score(val_targ, val_predict)

        self.val_epoch_mse.append(val_epoch_mse)

        # Add custom metrics to the logs, so that we can use them with
        # EarlyStop and csvLogger callbacks
        logs["val_epoch_mse"] = val_epoch_mse

        print(f"\nEpoch: {epoch + 1}")
        print("-----------------")
        print("val_mse:     {:+.6f}".format(val_epoch_mse))

        return
Run Code Online (Sandbox Code Playgroud)

我目前的解决方法如下。我只是将validation_data作为ValMetrics类的参数:

class ValMetrics(Callback):

    def __init__(self, validation_data):
        super(Callback, self).__init__()
        self.X_val, self.y_val = validation_data
Run Code Online (Sandbox Code Playgroud)

我仍然有一些问题:“validation_data”属性真的被弃用了还是可以在其他地方找到?有没有比上述解决方法更好的方法来访问每个时期结束时的验证数据?

非常感谢!

Ten*_*ior 5

您是对的,validation_data根据Tensorflow 回调文档,该参数已被弃用。

您所面临的问题已在 Github 中提出。相关问题Issue1Issue2Issue3

上述 Github 问题均未解决Validation_Data,并且根据此Github 评论,您将作为参数传递给自定义回调的解决方法是一个很好的解决方法,因为许多人发现它很有用。

指定下面的解决方法的代码,为了Stackoverflow Community,即使它存在于 Github 中。

class Metrics(Callback):

    def __init__(self, val_data, batch_size = 20):
        super().__init__()
        self.validation_data = val_data
        self.batch_size = batch_size

    def on_train_begin(self, logs={}):
        print(self.validation_data)
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs={}):
        batches = len(self.validation_data)
        total = batches * self.batch_size

        val_pred = np.zeros((total,1))
        val_true = np.zeros((total))

        for batch in range(batches):
            xVal, yVal = next(self.validation_data)
            val_pred[batch * self.batch_size : (batch+1) * self.batch_size] = np.asarray(self.model.predict(xVal)).round()
            val_true[batch * self.batch_size : (batch+1) * self.batch_size] = yVal

        val_pred = np.squeeze(val_pred)
        _val_f1 = f1_score(val_true, val_pred)
        _val_precision = precision_score(val_true, val_pred)
        _val_recall = recall_score(val_true, val_pred)

        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)

        return
Run Code Online (Sandbox Code Playgroud)

我将继续关注上面提到的 Github 问题,并将相应地更新答案。

希望这可以帮助。快乐学习!