tf.train.init_from_checkpoint不会初始化使用tf.Variable创建的变量

use*_*974 11 python tensorflow

似乎是tf.train.init_from_checkpoint初始化通过 创建的变量,tf.get_variable不是通过创建的变量tf.Variable

例如,让我们创建两个变量并保存它们:

import tensorflow as tf

tf.Variable(1.0, name='foo')
tf.get_variable('bar',initializer=1.0)
saver = tf.train.Saver()
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  saver.save(sess, './model', global_step=0)
Run Code Online (Sandbox Code Playgroud)

如果我再次通过加载它们tf.train.Saver,则一切正常:即使变量在此处初始化为零,变量也将重新加载为1:

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, './model-0')
  print(f'foo: {foo.eval()}  bar: {bar.eval()}')
  # foo: 1.0  bar: 1.0
Run Code Online (Sandbox Code Playgroud)

但是,如果我使用tf.train.init_from_checkpoint我得到

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
tf.train.init_from_checkpoint('./model-0', {'/':'/'})
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  print(f'foo: {foo.eval()}  bar: {bar.eval()}')
  # foo: 0.0  bar: 1.0
Run Code Online (Sandbox Code Playgroud)

bar按预期设置为1,但foo仍为0。

这是预期的行为吗?如果是这样,为什么?

Sha*_*rky 4

是的,这是有意的。方法中描述了此行为_init_from_checkpoint,该方法在加载要恢复的变量时迭代赋值映射。

 for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
Run Code Online (Sandbox Code Playgroud)

它首先设置要恢复的变量None,如果满足几个条件之一,它将重置为当前变量名称。在这种特殊情况下,循环包含语句

if "/" in current_var_or_name

因此,它将从store_vars之前创建的字典中加载变量。_init_from_checkpoint它是在检查赋值映射中的当前变量是否为后立即创建的tf.Variable,此时为 False。

 if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars 
Run Code Online (Sandbox Code Playgroud)

store_vars是由内部类创建的_VariableStore,更准确地说,是由它的_get_default_variable_store()方法创建的。该类用作get_variable变量构造函数。由于tf.Variable没有默认作用域,因此tf.get_variable首先调用 tf.get_variable_scope(),它返回当前变量作用域。'foo' 超出了这个范围。此外,tf.Variable每次调用时都会创建一个新变量,并且不允许共享。

store_vars是从默认作用域成员构造的,因此,它仅包含“bar”变量,并foo稍后使用 op 添加到变量集合中tf.Variable

但是,如果assignment_mapwill contains {'foo':foo, 'bar':bar},则上述 for_init_from_checkpoint将找到这些变量并加载它们。所以在这种情况下你的代码将输出foo: 1.0 bar: 1.0

您可以在https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/checkpoint_utils.py中找到代码

https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/variable_scope.py 另请参阅此答案What is the default variable_scope in Tensorflow?