使用Tensorflow batch_norm函数获得较低的测试精度

Has*_*nat 4 python tensorflow

我在MNIST数据上使用Tensorflow 的官方批处理规范化(BN)函数(tf.contrib.layers.batch_norm())。我使用以下代码添加BN:

local4_bn = tf.contrib.layers.batch_norm(local4, is_training=True)
Run Code Online (Sandbox Code Playgroud)

在测试过程中,我在上面的代码行中更改了“ is_training = False”,并观察到仅20%的准确性。但是,如果我也将上述代码也用于批处理100张图像的测试(即,保持is_training = True),则它的精度约为99%。该观察表明,指数移动平均值和方差 batch_norm()可能不正确,或者我的代码中缺少某些内容。

任何人都可以回答有关上述问题的解决方案。

nes*_*uno 5

is_training=True仅由于批次大小为100而进行模型测试时,您可以获得〜99%的准确性。如果将批次大小更改为1,则准确性会降低。

这是由于以下事实:您正在计算输入批处理的指数移动平均值和方差,而不是使用这些值(批处理)归一化输出的图层。

batch_norm函数的参数variables_collections可帮助您在训练阶段存储计算的移动平均值和方差,并在测试阶段重新使用它们。

如果为这些变量定义一个集合,则该batch_norm层将在测试阶段使用它们,而不是计算新值。

因此,如果将批次归一化层定义更改为

local4_bn = tf.contrib.layers.batch_norm(local4, is_training=True, variables_collections=["batch_norm_non_trainable_variables_collection"])
Run Code Online (Sandbox Code Playgroud)

该层将计算出的变量存储到"batch_norm_non_trainable_variables_collection"集合中。

在测试阶段,当您传递is_training=False参数时,图层将重新使用在集合中找到的计算值。

请注意,移动平均值和方差不是可训练的参数,因此,如果仅将模型可训练参数保存在检查点文件中,则必须手动将存储的不可训练变量添加到先前定义的集合中。

创建Saver对象时可以执行以下操作:

saver = tf.train.Saver(tf.get_trainable_variables() + tf.get_collection_ref("batch_norm_non_trainable_variables_co??llection") + otherlistofvariables)
Run Code Online (Sandbox Code Playgroud)

在成瘾,因为批标准化可以限制被施加到(因为它限制了的值的范围),则应该使网络学习的参数的层的表现力gammabeta(在所描述的仿射变换系数),该允许网络学习仿射变换,从而增加层的表示能力。

您可以通过以下方式启用对功能True参数的这些参数设置的学习batch_norm

local4_bn = tf.contrib.layers.batch_norm(
    local4,
    is_training=True,
    center=True, # beta
    scale=True, # gamma
    variables_collections=["batch_norm_non_trainable_variables_collection"])
Run Code Online (Sandbox Code Playgroud)