NotImplementedError:学习率计划必须覆盖 get_config

has*_*ooq 6 python machine-learning transformer-model keras tensorflow2.0

我已经使用 tf.keras 创建了一个自定义计划,并且在保存模型时遇到了这个错误:

NotImplementedError:学习率计划必须覆盖 get_config

这个类看起来像这样:

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps**-1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        config = {
            'd_model':self.d_model,
            'warmup_steps':self.warmup_steps

        }
        base_config = super(CustomSchedule, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
Run Code Online (Sandbox Code Playgroud)

Mon*_* Eb 6

当您使用自定义子类模型时,保存模型架构有点棘手。相反,使用 Model.save_weights() 仅保存权重会更容易。

如果将代码更改为此,您将不会看到该错误:

  def get_config(self):
    config = {
    'd_model': self.d_model,
    'warmup_steps': self.warmup_steps,

     }
    return config
Run Code Online (Sandbox Code Playgroud)