TensorFlow `tf.layers.batch_normalization` does not add update ops to `tf.GraphKeys.UPDATE_OPS`

Dav*_*rks 8 python tensorflow

The following code (runnable as-is via copy/paste) illustrates the issue when using tf.layers.batch_normalization:

import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]))
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

> []     # UPDATE_OPS collection is empty

With TF 1.5, the documentation (quoted below, https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization) clearly states that UPDATE_OPS should not be empty in this case:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

pfm*_*pfm 5

Simply switch your code to training mode (by setting the training flag to True), as mentioned in the quote:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op.

 import tensorflow as tf
 bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
 print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

This will output:

[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(1,) dtype=float32_ref>,
 <tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(1,) dtype=float32_ref>]

And gamma and beta end up in the TRAINABLE_VARIABLES collection:

print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

[<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>, 
 <tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>]
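
For completeness, here is a minimal end-to-end sketch (not part of the original answer) that combines both points: the layer is built with training=True so the moving-average update ops land in tf.GraphKeys.UPDATE_OPS, and the train op is wrapped in tf.control_dependencies so those updates actually run on every training step. The input placeholders, the dense layer, the loss, and the learning rate are illustrative choices, not anything prescribed by the question.

import tensorflow as tf

# Toy inputs; shapes and names are illustrative only.
x = tf.placeholder(tf.float32, shape=[None, 4])
labels = tf.placeholder(tf.float32, shape=[None, 1])

# Batch norm in training mode, so the moving_mean / moving_variance
# update ops are registered in tf.GraphKeys.UPDATE_OPS.
h = tf.layers.batch_normalization(x, training=True)
logits = tf.layers.dense(h, 1)
loss = tf.losses.mean_squared_error(labels, logits)

# Pattern from the documentation: make the train op depend on the
# update ops collected in UPDATE_OPS so they run on every step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

At inference time the same layer would be built with training=False so that the accumulated moving statistics are used instead of the batch statistics.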