如何使用 Keras 中的 Adam 优化器在每个时期打印学习率?

Zac*_*ach 8 python machine-learning neural-network deep-learning keras

因为当您使用自适应优化器时(调用时学习率计划会重置.fit()),在线学习与 Keras 无法很好地配合使用,所以我想看看是否可以手动设置它。然而,为了做到这一点,我需要找出最后一个时期的学习率。

也就是说,我如何打印每个时期的学习率?我想我可以通过回调来做到这一点,但似乎你每次都必须重新计算它,我不知道如何对亚当做到这一点。

我在另一个线程中找到了这个,但它只适用于 SGD:

class SGDLearningRateTracker(Callback):
    def on_epoch_end(self, epoch, logs={}):
        optimizer = self.model.optimizer
        lr = K.eval(optimizer.lr * (1. / (1. + optimizer.decay * optimizer.iterations)))
        print('\nLR: {:.6f}\n'.format(lr))
Run Code Online (Sandbox Code Playgroud)

jor*_*mit 6

我发现这个问题非常有帮助。回答您的问题的最小可行示例是:

def get_lr_metric(optimizer):
    def lr(y_true, y_pred):
        return optimizer.lr
    return lr

optimizer = keras.optimizers.Adam()
lr_metric = get_lr_metric(optimizer)

model.compile(
    optimizer=optimizer,
    metrics=['accuracy', lr_metric],
    loss='mean_absolute_error', 
    )
Run Code Online (Sandbox Code Playgroud)


And*_*rey 5

我正在使用以下方法,该方法基于 @jorijnsmit 答案:

def get_lr_metric(optimizer):
    def lr(y_true, y_pred):
        return optimizer._decayed_lr(tf.float32) # I use ._decayed_lr method instead of .lr
    return lr

optimizer = keras.optimizers.Adam()
lr_metric = get_lr_metric(optimizer)

model.compile(
    optimizer=optimizer,
    metrics=['accuracy', lr_metric],
    loss='mean_absolute_error', 
    )
Run Code Online (Sandbox Code Playgroud)

它与亚当一起工作。


Bes*_*low 5

对于仍然对这个话题感到困惑的每个人:

@Andrey 的解决方案有效,但只有当您设置学习率衰减时,您必须安排学习率在“n”纪元后降低自身,否则它将始终打印相同的数字(起始学习率),这是因为这个数字在训练过程中不会改变,你看不到学习率如何适应,因为 Adam 中的每个参数都有不同的学习率,在训练过程中会自我适应,但变量lr永远不会改变


Tus*_*pta 1

class MyCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        lr = self.model.optimizer.lr
        # If you want to apply decay.
        decay = self.model.optimizer.decay
        iterations = self.model.optimizer.iterations
        lr_with_decay = lr / (1. + decay * K.cast(iterations, K.dtype(decay)))
        print(K.eval(lr_with_decay))
Run Code Online (Sandbox Code Playgroud)

关注主题。

  • 这不是 Adam 使用的学习率。这是带有衰减的 SGD。 (2认同)