在 Tensorflow 中,我的模型基于预训练模型,我添加了更多变量并删除了预训练模型中的一些变量。当我从检查点文件中恢复变量时,我必须明确指定我添加到图中需要排除的所有变量。例如,我做了
exclude = # explicitly list all variables to exclude
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
Run Code Online (Sandbox Code Playgroud)
有没有更简单的方法来做到这一点?即,只要变量不在检查点中,就不要尝试恢复。
您唯一可以做的就是首先使用与检查点中相同的模型,然后将检查点值恢复到相同的模型。恢复同一模型的变量后,您可以添加新层、删除现有层或更改层的权重。
但有一点很重要,你需要小心。添加新图层后,您需要初始化它们。如果使用tf.global_variables_initializer(),您将丢失重新加载的图层的值。因此,您应该只初始化未初始化的权重,您可以使用以下函数来实现此目的。
def initialize_uninitialized(sess):
global_vars = tf.global_variables()
is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
# for i in not_initialized_vars: # only for testing
# print(i.name)
if len(not_initialized_vars):
sess.run(tf.variables_initializer(not_initialized_vars))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
6371 次 |
| 最近记录: |