如何仅恢复 Tensorflow 中检查点中的变量?

use*_*379 5 tensorflow

在 Tensorflow 中,我的模型基于预训练模型,我添加了更多变量并删除了预训练模型中的一些变量。当我从检查点文件中恢复变量时,我必须明确指定我添加到图中需要排除的所有变量。例如,我做了

exclude = # explicitly list all variables to exclude
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
Run Code Online (Sandbox Code Playgroud)

有没有更简单的方法来做到这一点?即,只要变量不在检查点中,就不要尝试恢复。

ene*_*ski 1

您唯一可以做的就是首先使用与检查点中相同的模型,然后将检查点值恢复到相同的模型。恢复同一模型的变量后,您可以添加新层、删除现有层或更改层的权重。

但有一点很重要,你需要小心。添加新图层后,您需要初始化它们。如果使用tf.global_variables_initializer(),您将丢失重新加载的图层的值。因此,您应该只初始化未初始化的权重,您可以使用以下函数来实现此目的。

def initialize_uninitialized(sess):
    global_vars = tf.global_variables()
    is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]

    # for i in not_initialized_vars: # only for testing
    #    print(i.name)

    if len(not_initialized_vars):
        sess.run(tf.variables_initializer(not_initialized_vars))
Run Code Online (Sandbox Code Playgroud)