tensorflow 2.0 自定义训练循环的学习率

yun*_*yun 8 python tensorflow

当我使用 tensorflow 2.0 自定义训练循环时,是否有任何函数或方法可以显示学习率?

这是张量流指南的示例:

def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)
Run Code Online (Sandbox Code Playgroud)

模型训练时如何从优化器中检索当前学习率?

如果您能提供任何帮助,我将不胜感激。:)

P S*_*ved 7

在 Tensorflow 2.1 中,Optimizer 类有一个未公开的方法_decayed_lr(请参阅此处的定义),您可以通过提供要强制转换为的变量类型在训练循环中调用该方法:

current_learning_rate = optimizer._decayed_lr(tf.float32)
Run Code Online (Sandbox Code Playgroud)

这里还有一个更完整的 TensorBoard 示例。

train_step_count = 0
summary_writer = tf.summary.create_file_writer('logs/')
def train_step(images, labels):
  train_step_count += 1
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # optimizer._decayed_lr(tf.float32) is the current Learning Rate.
  # You can save it to TensorBoard like so:
  with summary_writer.as_default():
    tf.summary.scalar('learning_rate',
                      optimizer._decayed_lr(tf.float32),
                      step=train_step_count)
Run Code Online (Sandbox Code Playgroud)

  • 正是我所需要的。谢谢! (2认同)

zih*_*hao 6

在自定义训练循环设置中,您可以print(optimizer.lr.numpy())获取学习率。

如果您使用的是 keras api,您可以定义自己的回调来记录当前的学习率。

from tensorflow.keras.callbacks import Callback

class LRRecorder(Callback):
    """Record current learning rate. """
    def on_epoch_begin(self, epoch, logs=None):
        lr = self.model.optimizer.lr
        print("The current learning rate is {}".format(lr.numpy()))

# your other callbacks 
callbacks.append(LRRecorder())
Run Code Online (Sandbox Code Playgroud)

更新

w := w - (base_lr*m/sqrt(v))*grad = w - act_lr*grad 我们上面得到的学习率是base_lr。然而,act_lr在训练期间是适应性改变的。以 Adam 优化器为例,act_lrbase_lrm和决定vmv是参数的第一和第二动量。不同的参数有不同的mv值。因此,如果您想知道act_lr,则需要知道变量的名称。例如,你想知道act_lr变量的Adam/dense/kernel,你可以像这样访问mv

for var in optimizer.variables():
  if 'Adam/dense/kernel/m' in var.name:
    print(var.name, var.numpy())

  if 'Adam/dense/kernel/v' in var.name:
    print(var.name, var.numpy())
Run Code Online (Sandbox Code Playgroud)

然后您可以act_lr使用上述公式轻松计算。