会话恢复后,get_variable()不起作用

Ale*_*der 5 tensorflow

我尝试恢复会话并调用get_variable()以获取类型为tf.Variable的对象(根据此答案).它无法找到变量.重现案例的最小例子如下.

首先,创建一个变量并保存会话.

import tensorflow as tf

var = tf.Variable(101)

with tf.Session() as sess:
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])

    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])

    assert scoped_var is new_scoped_var
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    print(sess.run(scoped_var))
    saver.save(sess, 'data/sess')
Run Code Online (Sandbox Code Playgroud)

这里get_variables有一个reuse=True工作正常的范围.然后,从文件中恢复会话并尝试获取变量.

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')

    for v in tf.get_collection('variables'):
        print(v.name)

    print(tf.get_collection(("__variable_store",)))
    # Oops, it's empty!

    with tf.variable_scope('', reuse=True):
        # the next line fails
        new_scoped_var = tf.get_variable('scoped_var', [])

    print("new_scoped_var: ", new_scoped_var)
Run Code Online (Sandbox Code Playgroud)

输出:

Variable:0
scoped_var:0
[]
Traceback (most recent call last):
...
ValueError: Variable scoped_var does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?
Run Code Online (Sandbox Code Playgroud)

我们可以看到,get_variable()找不到变量.和 ("__variable_store",)集合,在内部被使用get_variable(),是空的.

为什么get_variable失败?

Ste*_*ven 1

您可以尝试这个,而不是处理元图(如果您想修改图及其加载方式等,这会很有帮助)。

import tensorflow as tf

with tf.Session() as sess:
  with tf.variable_scope(''):
    scoped_var = tf.get_variable('scoped_var', [])

  with tf.variable_scope('', reuse=True):
    new_scoped_var = tf.get_variable('scoped_var', [])

  assert scoped_var is new_scoped_var
  saver = tf.train.Saver()
  path = tf.train.get_checkpoint_state('data/sess')
  if path is not None:
    saver.restore(sess, path.model_checkpoint_path)
  else:
    sess.run(tf.global_variables_initializer())

  print(sess.run(scoped_var))
  saver.save(sess, 'data/sess')

  #now continue to use as you normally would with a restored model
Run Code Online (Sandbox Code Playgroud)

主要区别是您在调用 saver.restore 之前已经设置了模型