在TensorFlow中有没有办法初始化未初始化的变量?

Dan*_*ter 48 python tensorflow

在TensorFlow中初始化变量的标准方法是

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
Run Code Online (Sandbox Code Playgroud)

在运行了一段时间的学习后,我创建了一组新的变量,但是一旦我初始化它们,它就会重置我现有的所有变量.目前我的方法是保存我需要的所有变量,然后在tf.initalize_all_variables调用之后重新应用它们.这有效,但有点丑陋和笨重.我在文档中找不到这样的东西......

有没有人知道刚刚初始化未初始化变量的任何好方法?

mrr*_*rry 35

没有优雅的方法来枚举图中未初始化的变量.但是,如果你有机会到新的变量对象,让我们给他们打电话v_6,v_7v_8-你可以选择性地使用它们进行初始化tf.initialize_variables():

init_new_vars_op = tf.initialize_variables([v_6, v_7, v_8])
sess.run(init_new_vars_op)
Run Code Online (Sandbox Code Playgroud)

*可以使用试错过程来识别未初始化的变量,如下所示:

uninitialized_vars = []
for var in tf.all_variables():
    try:
        sess.run(var)
    except tf.errors.FailedPreconditionError:
        uninitialized_vars.append(var)

init_new_vars_op = tf.initialize_variables(uninitialized_vars)
# ...
Run Code Online (Sandbox Code Playgroud)

...但是,我不会宽恕这种行为:-).

  • Tensorflow 0.9具有可能有用的函数tf.report_uninitialized_variables(). (4认同)

Poi*_*oik 31

UPDATE: TensorFlow 0.9有一个新的方法"修复"所有这一切,但只有当您使用的是VariableScopereuse设置为True.tf.report_uninitialized_variables可以在一行中使用sess.run( tf.initialize_variables( list( tf.get_variable(name) for name in sess.run( tf.report_uninitialized_variables( tf.all_variables( ) ) ) ) ) )

或者通过指定您希望初始化的变量的能力更智能:

def guarantee_initialized_variables(session, list_of_variables = None):
    if list_of_variables is None:
        list_of_variables = tf.all_variables()
    uninitialized_variables = list(tf.get_variable(name) for name in
                                   session.run(tf.report_uninitialized_variables(list_of_variables)))
    session.run(tf.initialize_variables(uninitialized_variables))
    return unintialized_variables
Run Code Online (Sandbox Code Playgroud)

这仍然不如实际知道哪些变量是初始化和未初始化并且正确处理这些变量,但在类似误导的情况下optim(见下文)可能很难避免.

另请注意,tf.initialize_variables无法评估tf.report_uninitialized_variables,因此它们都必须在会话的上下文中运行才能工作.


这样做有一种不雅但简洁的方法.在介绍新变量之前运行temp = set(tf.all_variables())并之后运行sess.run(tf.initialize_variables(set(tf.all_variables()) - temp)).这些一起只会初始化分配临时值后创建的任何变量.

我一直在玩转学习,所以我想要一个快速的方法来做,但这是我能找到的最好方法.特别是当使用像AdamOptimizer这样的东西时,它不会让你轻松(或任何,我不确定)访问它使用的变量.所以以下实际显示在我的代码中.(我明确地初始化了新图层的变量,并在转移学习之前运行一次以显示初始错误.只是为了进行健全性检查.)

temp = set(tf.all_variables())
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#I honestly don't know how else to initialize ADAM in TensorFlow.
sess.run(tf.initialize_variables(set(tf.all_variables()) - temp))
Run Code Online (Sandbox Code Playgroud)

它解决了我所有的问题.

编辑: @Lifu_Huang的回答说明了解决问题的正确方法.从理论上讲,你应该使用tf.train.Optimizer.get_slot_namestf.train.Optimizer.get_slot:

optim = tf.train.AdadeltaOptimizer(1e-4)
loss = cross_entropy(y,yhat)
train_step = optim.minimize(loss)
sess.run(tf.initialize_variables([optim.get_slot(loss, name)
                                  for name in optim.get_slot_names()])
Run Code Online (Sandbox Code Playgroud)

然而,这给了我AttributeError: 'NoneType' object has no attribute 'initializer'.当我弄清楚我做错了什么时,我会做编辑,所以你不要犯错误.

  • 请注意,尽管有插槽,但优化器可以创建其他变量.对我来说,'AdamOptimizer`也会创建变量`[<tf.Variable'trode/beta1_power:0'shain =()dtype = float32_ref>,<tf.Variable'entimation/beta2_power:0'shape =()dtype = float32_ref> ]`它与可训练的var不对应,因此你不能将它们作为槽. (2认同)

Sal*_*ali 27

TF 没有完全符合你想要的功能,但是你可以轻松地写一个:

import tensorflow as tf

def initialize_uninitialized(sess):
    global_vars          = tf.global_variables()
    is_not_initialized   = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]

    print [str(i.name) for i in not_initialized_vars] # only for testing
    if len(not_initialized_vars):
        sess.run(tf.variables_initializer(not_initialized_vars))
Run Code Online (Sandbox Code Playgroud)

在这里,我提取所有全局变量,迭代所有变量并检查它们是否已经初始化.在此之后,我得到我未初始化的变量列表初始化.我还打印了我要初始化的变量以用于调试目的.


您可以轻松验证它是否按预期工作:

a = tf.Variable(3, name='my_var_a')
b = tf.Variable(4, name='my_var_b')

sess = tf.Session()
initialize_uninitialized(sess)
initialize_uninitialized(sess)

c = tf.Variable(5, name='my_var_a') # the same name, will be resolved to different name
d = tf.Variable(6, name='my_var_d')
initialize_uninitialized(sess)

print '\n\n', sess.run([a, b, c, d])
Run Code Online (Sandbox Code Playgroud)

这将在初始化它们之前打印所有单元化变量,最后一个sess.run将确保说服所有变量都被初始化.


您还可以使用tf.report_uninitialized_variables()编写类似的功能.它的草图就在这里.