当我使用 tensorflow 2.0 自定义训练循环时,是否有任何函数或方法可以显示学习率?
这是张量流指南的示例:
def train_step(images, labels):
with tf.GradientTape() as tape:
predictions = model(images)
loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
train_accuracy(labels, predictions)
Run Code Online (Sandbox Code Playgroud)
模型训练时如何从优化器中检索当前学习率?
如果您能提供任何帮助,我将不胜感激。:)
在 Tensorflow 2.1 中,Optimizer 类有一个未公开的方法_decayed_lr(请参阅此处的定义),您可以通过提供要强制转换为的变量类型在训练循环中调用该方法:
current_learning_rate = optimizer._decayed_lr(tf.float32)
Run Code Online (Sandbox Code Playgroud)
这里还有一个更完整的 TensorBoard 示例。
train_step_count = 0
summary_writer = tf.summary.create_file_writer('logs/')
def train_step(images, labels):
train_step_count += 1
with tf.GradientTape() as tape:
predictions = model(images)
loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# optimizer._decayed_lr(tf.float32) is the current Learning Rate.
# You can save it to TensorBoard like so:
with summary_writer.as_default():
tf.summary.scalar('learning_rate',
optimizer._decayed_lr(tf.float32),
step=train_step_count)
Run Code Online (Sandbox Code Playgroud)
在自定义训练循环设置中,您可以print(optimizer.lr.numpy())获取学习率。
如果您使用的是 keras api,您可以定义自己的回调来记录当前的学习率。
from tensorflow.keras.callbacks import Callback
class LRRecorder(Callback):
"""Record current learning rate. """
def on_epoch_begin(self, epoch, logs=None):
lr = self.model.optimizer.lr
print("The current learning rate is {}".format(lr.numpy()))
# your other callbacks
callbacks.append(LRRecorder())
Run Code Online (Sandbox Code Playgroud)
w := w - (base_lr*m/sqrt(v))*grad = w - act_lr*grad
我们上面得到的学习率是base_lr。然而,act_lr在训练期间是适应性改变的。以 Adam 优化器为例,act_lr由base_lr、m和决定v。m和v是参数的第一和第二动量。不同的参数有不同的m和v值。因此,如果您想知道act_lr,则需要知道变量的名称。例如,你想知道act_lr变量的Adam/dense/kernel,你可以像这样访问m和v,
for var in optimizer.variables():
if 'Adam/dense/kernel/m' in var.name:
print(var.name, var.numpy())
if 'Adam/dense/kernel/v' in var.name:
print(var.name, var.numpy())
Run Code Online (Sandbox Code Playgroud)
然后您可以act_lr使用上述公式轻松计算。