Posts by use*_*600

TensorFlow Addons moving average

Is this implementation (https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/MovingAverage) the same as ExponentialMovingAverage in the tensorflow train module (https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage)?

import tensorflow as tf

# `model` is assumed to be a tf.keras.Model defined elsewhere, taking three
# inputs and returning two outputs (start and end).
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
loss_obj = tf.keras.losses.CategoricalCrossentropy()

# Create the EMA object once, outside the tf.function, so that its shadow
# variables are created a single time rather than on every call.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)

@tf.function
def train_step(inputs, outputs):
    with tf.GradientTape() as tape:
        start, end = model([inputs[0], inputs[1], inputs[2]], training=True)
        start_truth, end_truth = tf.squeeze(outputs[0]), tf.squeeze(outputs[1])
        start_loss = loss_obj(start_truth, start)
        end_loss = loss_obj(end_truth, end)
        total_loss = start_loss + end_loss

    model_gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(model_gradients, model.trainable_variables))

    # Update the moving averages after the optimizer step. Inside a
    # tf.function, stateful ops are ordered as written, so an explicit
    # tf.control_dependencies block should not be needed here.
    ema.apply(model.trainable_variables)

    return total_loss, start_loss, end_loss
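
For comparing the two classes, it may help to have the update rule written out. The documentation for tf.train.ExponentialMovingAverage describes each step as shadow = decay * shadow + (1 - decay) * variable, with the shadow value initialized from the variable's initial value. Below is a minimal plain-Python sketch of that rule only. It does not use TensorFlow, the names `ema_update` and `run_ema` are hypothetical, and it makes no claim about how tfa.optimizers.MovingAverage is implemented internally.

```python
def ema_update(shadow, value, decay):
    """One EMA step: shadow <- decay * shadow + (1 - decay) * value."""
    return decay * shadow + (1.0 - decay) * value


def run_ema(values, decay, initial):
    """Apply the EMA step to a sequence of values, returning every shadow value.

    `initial` stands in for the shadow variable's starting value (an assumption
    of this sketch; pass the variable's own initial value to mimic the
    documented TensorFlow behaviour).
    """
    shadow = initial
    history = []
    for v in values:
        shadow = ema_update(shadow, v, decay)
        history.append(shadow)
    return history


# A small decay is used here only so the effect is visible in a few steps.
history = run_ema([1.0, 1.0, 1.0], decay=0.5, initial=0.0)
print(history)  # [0.5, 0.75, 0.875]
```

With decay=0.9999 as in the question, the shadow value moves far more slowly, by 0.01% of the gap to the current value per step.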


moving-average tensorflow

5 votes
1 answer
431 views
