Question by use*_*600 (score 5). Tags: moving-average, tensorflow
Is this implementation (https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/MovingAverage) the same as ExponentialMovingAverage in the TensorFlow train module (https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage)?
```python
import tensorflow as tf

# `model` is assumed to be a Keras model defined elsewhere that takes three
# inputs and returns (start, end) predictions.
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
loss_obj = tf.keras.losses.CategoricalCrossentropy()

# Create the EMA object once, outside the tf.function, so its shadow variables
# persist across steps and stay reachable via ema.average(var).
ema = tf.train.ExponentialMovingAverage(decay=0.9999)

@tf.function
def train_step(inputs, outputs):
    with tf.GradientTape() as tape:
        start, end = model([inputs[0], inputs[1], inputs[2]], training=True)
        start_truth, end_truth = tf.squeeze(outputs[0]), tf.squeeze(outputs[1])
        start_loss = loss_obj(start_truth, start)
        end_loss = loss_obj(end_truth, end)
        total_loss = start_loss + end_loss
    model_gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(model_gradients, model.trainable_variables))
    # tf.function runs stateful ops in program order, so the shadow variables
    # are updated from the post-step weights without explicit control deps.
    ema.apply(model.trainable_variables)
    return total_loss, start_loss, end_loss
```
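For reference, the update that `ema.apply` performs on each shadow variable is documented for `tf.train.ExponentialMovingAverage` as `shadow_variable -= (1 - decay) * (shadow_variable - variable)`, i.e. `shadow = decay * shadow + (1 - decay) * variable`. The framework-free sketch below replays that rule on a list of weight values; the helper `ema_trajectory` is hypothetical, and we assume (not verified here) that the tfa wrapper applies the same rule with its own `average_decay` when no `num_updates` or dynamic decay is configured. It mainly shows how much the choice of decay matters, e.g. 0.9999 versus 0.99.

```python
def ema_trajectory(weights, decay):
    """Return the shadow (averaged) value after each step.

    Hypothetical helper: weights[0] is the initial weight, which the shadow
    starts as a copy of; weights[1:] are the weight values after each
    optimizer step.
    """
    shadow = weights[0]
    history = []
    for w in weights[1:]:
        # Equivalent to: shadow -= (1 - decay) * (shadow - w)
        shadow = decay * shadow + (1.0 - decay) * w
        history.append(shadow)
    return history


# A weight jumps from 0.0 to 1.0 and stays there for 1000 steps.
steps = [0.0] + [1.0] * 1000
print(ema_trajectory(steps, 0.9999)[-1])  # about 0.095: the average lags far behind
print(ema_trajectory(steps, 0.99)[-1])    # about 0.99996: the average has caught up
```

With a constant target the shadow after n steps is 1 - decay**n, so two otherwise identical setups that differ only in decay produce very different averaged weights.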
In other words, is the code above the same as the following?
```python
import tensorflow as tf
import tensorflow_addons as tfa

# `model` is assumed to be the same Keras model as in the first snippet.
# The wrapper keeps the averages in optimizer slots; its average_decay
# defaults to 0.99, unlike the decay=0.9999 used above.
optimizer = tfa.optimizers.MovingAverage(tf.keras.optimizers.Adam(learning_rate=5e-5))
loss_obj = tf.keras.losses.CategoricalCrossentropy()

@tf.function
def train_step(inputs, outputs):
    with tf.GradientTape() as tape:
        start, end = model([inputs[0], inputs[1], inputs[2]], training=True)
        start_truth, end_truth = tf.squeeze(outputs[0]), tf.squeeze(outputs[1])
        start_loss = loss_obj(start_truth, start)
        end_loss = loss_obj(end_truth, end)
        total_loss = start_loss + end_loss
    model_gradients = tape.gradient(total_loss, model.trainable_variables)
    # Applies the Adam step, then updates the moving averages.
    optimizer.apply_gradients(zip(model_gradients, model.trainable_variables))
    return total_loss, start_loss, end_loss
```
Answer by 小智 (score 0)
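These implementations are not equivalent. Your first implementation creates shadow variables that are stored in the GraphKeys.MOVING_AVERAGE_VARIABLES and GraphKeys.ALL_VARIABLES collections. That version can only be used within the tf.compat.v1.estimator.Estimator framework or in a tf.compat.v1.train context. In the second implementation, the shadow variables are stored in slots created by the optimizer and are updated during the call to Optimizer.apply_gradients. This makes it straightforward to swap the averaged and non-averaged model weights for evaluation and when saving the model.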
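The weight-swapping idea can be illustrated without TensorFlow. In this sketch a dict of floats stands in for the model variables and a second dict stands in for the averages (the EMA shadow variables or the wrapper's per-variable average slots); the helper `averaged_weights` is hypothetical and not a TensorFlow or tfa API. In real code the averages would come from `ema.average(var)` for `tf.train.ExponentialMovingAverage`, or be copied into the model by the tfa wrapper's `assign_average_vars(var_list)`. As far as we can tell, the latter only copies the averages in, so the raw weights need to be backed up first if training is to resume.

```python
from contextlib import contextmanager


@contextmanager
def averaged_weights(weights, averages):
    """Temporarily overwrite `weights` in place with their moving averages.

    Hypothetical helper: `weights` maps variable names to their current (raw)
    values; `averages` maps the same names to the averaged values.
    """
    backup = dict(weights)    # keep the raw training weights
    weights.update(averages)  # evaluate or export with the averaged weights
    try:
        yield weights
    finally:
        # Restore the raw weights, even if evaluation raised, so training
        # can continue from where it left off.
        weights.clear()
        weights.update(backup)


# Usage: evaluate with the averaged weights, then carry on training.
model_weights = {"kernel": 1.0, "bias": 2.0}
shadow = {"kernel": 0.5, "bias": 1.5}
with averaged_weights(model_weights, shadow) as current:
    print(current)    # {'kernel': 0.5, 'bias': 1.5}
print(model_weights)  # {'kernel': 1.0, 'bias': 2.0}
```

Mutating the container in place mirrors how model variables are reassigned rather than replaced, and the `finally` block ensures the raw weights come back even if evaluation fails.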