use*_*974 11 python tensorflow
似乎是tf.train.init_from_checkpoint
初始化通过 创建的变量,tf.get_variable
而不是通过创建的变量tf.Variable
。
例如,让我们创建两个变量并保存它们:
import tensorflow as tf
tf.Variable(1.0, name='foo')
tf.get_variable('bar',initializer=1.0)
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, './model', global_step=0)
Run Code Online (Sandbox Code Playgroud)
如果我再次通过加载它们tf.train.Saver
,则一切正常:即使变量在此处初始化为零,变量也将重新加载为1:
import tensorflow as tf
foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './model-0')
print(f'foo: {foo.eval()} bar: {bar.eval()}')
# foo: 1.0 bar: 1.0
Run Code Online (Sandbox Code Playgroud)
但是,如果我使用tf.train.init_from_checkpoint
我得到
import tensorflow as tf
foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
tf.train.init_from_checkpoint('./model-0', {'/':'/'})
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(f'foo: {foo.eval()} bar: {bar.eval()}')
# foo: 0.0 bar: 1.0
Run Code Online (Sandbox Code Playgroud)
bar
按预期设置为1,但foo
仍为0。
这是预期的行为吗?如果是这样,为什么?
是的,这是有意的。方法中描述了此行为_init_from_checkpoint
,该方法在加载要恢复的变量时迭代赋值映射。
for tensor_name_in_ckpt, current_var_or_name in sorted(
six.iteritems(assignment_map)):
var = None
Run Code Online (Sandbox Code Playgroud)
它首先设置要恢复的变量None
,如果满足几个条件之一,它将重置为当前变量名称。在这种特殊情况下,循环包含语句
if "/" in current_var_or_name
因此,它将从store_vars
之前创建的字典中加载变量。_init_from_checkpoint
它是在检查赋值映射中的当前变量是否为后立即创建的tf.Variable
,此时为 False。
if _is_variable(current_var_or_name) or (
isinstance(current_var_or_name, list)
and all(_is_variable(v) for v in current_var_or_name)):
var = current_var_or_name
else:
store_vars = vs._get_default_variable_store()._vars
Run Code Online (Sandbox Code Playgroud)
store_vars
是由内部类创建的_VariableStore
,更准确地说,是由它的_get_default_variable_store()
方法创建的。该类用作get_variable
变量构造函数。由于tf.Variable
没有默认作用域,因此tf.get_variable
首先调用 tf.get_variable_scope(),它返回当前变量作用域。'foo' 超出了这个范围。此外,tf.Variable
每次调用时都会创建一个新变量,并且不允许共享。
store_vars
是从默认作用域成员构造的,因此,它仅包含“bar”变量,并foo
稍后使用 op 添加到变量集合中tf.Variable
。
但是,如果assignment_map
will contains {'foo':foo, 'bar':bar}
,则上述 for_init_from_checkpoint
将找到这些变量并加载它们。所以在这种情况下你的代码将输出foo: 1.0 bar: 1.0
您可以在https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/checkpoint_utils.py中找到代码
和 https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/variable_scope.py 另请参阅此答案What is the default variable_scope in Tensorflow?
归档时间: |
|
查看次数: |
1346 次 |
最近记录: |