以下代码(复制/粘贴runnable)说明了使用tf.layers.batch_normalization.
import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]))
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
> [] # UPDATE_OPS collection is empty
Run Code Online (Sandbox Code Playgroud)
使用TF 1.5,文档(引用如下)明确指出在这种情况下UPDATE_OPS不应为空(https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization):
注意:训练时,需要更新moving_mean和moving_variance.默认情况下,更新操作被放入
tf.GraphKeys.UPDATE_OPS,因此需要将它们作为依赖项添加到train_op.例如:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
Run Code Online (Sandbox Code Playgroud)
只需将您的代码更改为处于训练模式(通过将training标志设置为True),如引用中所述:
注意:训练时,需要更新moving_mean和moving_variance.默认情况下,更新操作位于tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op.
import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
Run Code Online (Sandbox Code Playgroud)
将输出:
[< tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(1,) dtype=float32_ref>,
< tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(1,) dtype=float32_ref>]
Run Code Online (Sandbox Code Playgroud)
和Gamma和Beta最终在TRAINABLE_VARIABLES集合中:
print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
[<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>]
Run Code Online (Sandbox Code Playgroud)