Tensorflow保存/恢复批量规范

ALe*_*eex 7 tensorflow batch-normalization

我在Tensorflow中训练了一个具有批量规范的模型.我想保存模型并将其恢复以供进一步使用.批量规范由.完成

def batch_norm(input, phase):
    return tf.layers.batch_normalization(input, training=phase)
Run Code Online (Sandbox Code Playgroud)

阶段True在训练False期间和测试期间.

这似乎只是在呼唤

saver = tf.train.Saver()
saver.save(sess, savedir + "ckpt")
Run Code Online (Sandbox Code Playgroud)

不会很好,因为当我恢复模型时,它首先说成功恢复.它还说Attempting to use uninitialized value batch_normalization_585/beta如果我只是在图中运行一个节点.这是否与正确保存模型或我错过的其他内容有关?

sim*_*o23 8

我还有"尝试使用未初始化的值batch_normalization_585/beta"错误.这是因为通过像这样用空括号声明保护程序:

         saver = tf.train.Saver() 
Run Code Online (Sandbox Code Playgroud)

保存程序将保存tf.trainable_variables()中包含的变量,这些变量不包含批量标准化的移动平均值.要将此变量包含在已保存的ckpt中,您需要执行以下操作:

         saver = tf.train.Saver(tf.global_variables())
Run Code Online (Sandbox Code Playgroud)

这保存了所有变量,因此非常耗费内存.或者您必须识别具有移动平均值或方差的变量,并通过声明它们来保存它们:

         saver = tf.train.Saver(tf.trainable_variables() + list_of_extra_variables)
Run Code Online (Sandbox Code Playgroud)