Tensorflow变量范围:如果存在变量则重用

hol*_*lee 39 python tensorflow

我想要一段代码,如果它不存在,则在一个范围内创建一个变量,如果它已经存在,则访问该变量.我需要它是相同的代码,因为它将被多次调用.

但是,Tensorflow需要我指定是否要创建或重用变量,如下所示:

with tf.variable_scope("foo"): #create the first time
    v = tf.get_variable("v", [1])

with tf.variable_scope("foo", reuse=True): #reuse the second time
    v = tf.get_variable("v", [1])
Run Code Online (Sandbox Code Playgroud)

我怎样才能弄清楚是否自动创建或重用它?即,我希望上面两个代码块相同并运行程序.

rvi*_*nas 33

ValueErrorget_variable()创建新变量并且未声明形状时,或者在变量创建期间违反重用时,会引发A.因此,你可以试试这个:

def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            v = tf.get_variable(var, shape)
        except ValueError:
            scope.reuse_variables()
            v = tf.get_variable(var)
    return v

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2
Run Code Online (Sandbox Code Playgroud)

请注意,以下内容也适用:

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2
Run Code Online (Sandbox Code Playgroud)

UPDATE.新API现在支持自动重用:

def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v
Run Code Online (Sandbox Code Playgroud)

  • @holdenlee是的,只需将范围设置为'''`(即空字符串) (4认同)
  • 有没有办法用顶级范围(即没有范围)来做到这一点? (2认同)

Zho*_*ang 13

虽然使用"try ... except ..."子句可行,但我认为更优雅和可维护的方法是将变量初始化过程与"重用"过程分开.

def initialize_variable(scope_name, var_name, shape):
    with tf.variable_scope(scope_name) as scope:
        v = tf.get_variable(var_name, shape)
        scope.reuse_variable()

def get_scope_variable(scope_name, var_name):
    with tf.variable_scope(scope_name, reuse=True):
        v = tf.get_variable(var_name)
    return v
Run Code Online (Sandbox Code Playgroud)

因为我们通常只需要初始化变量,但是多次重用/共享它,将两个进程分开使代码更清晰.同样,这样,我们不需要每次都通过"try"子句检查变量是否已经创建.


小智 13

新的AUTO_REUSE选项可以解决问题.

tf.variable_scope API文档:if reuse=tf.AUTO_REUSE,如果它们不存在,我们创建变量,否则返回它们.

共享变量AUTO_REUSE的基本示例:

def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2
Run Code Online (Sandbox Code Playgroud)