Tensorflow:保存和重新分配会话-多个变量

Tom*_*Tom 1 python tensorflow

给出以下代码:

import tensorflow as tf

with tf.Session() as sess:
    var = tf.Variable(42, name='var')
    sess.run(tf.global_variables_initializer())
    tf.train.export_meta_graph('file.meta')

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('file.meta')
    print sess.run(var)
Run Code Online (Sandbox Code Playgroud)

我在一行上saver = tf.train.import_meta_graph('file.meta')说一个错误ValueError: At least two variables have the same name: var

我该如何解决?导入元图时,是否仍要覆盖计算图?

编辑:

我已经到达以下代码:

import tensorflow as tf

file_name = "./file"

with tf.Session() as sess:
    var = tf.Variable(42, name='my_var')
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()
    saver.save(sess,file_name)
    saver.export_meta_graph(file_name + '.meta')

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(file_name + '.meta')
    saver.restore(sess, file_name)
    print(sess.run(var))

    # new code that fails:
    saver = tf.train.Saver()
    saver.save(sess,file_name)
    saver.export_meta_graph(file_name + '.meta')
Run Code Online (Sandbox Code Playgroud)

这将为打印正确的值var,但是当我第二次保存图形时,我得到了相同的原始错误:ValueError: At least two variables have the same name: var

mar*_*ars 5

在这种情况下,您要将变量加载到已经定义了变量的默认图中。因此,您将需要在导入之前重置TensorFlow图。

使用进行此操作tf.reset_default_graph()。导入之前。查看“ 导出和导入元图”下的“在默认图中导入”

当然,你将不得不重新定义变量var使用tf.get_variable()。试试这个代码,

import tensorflow as tf

with tf.Session() as sess:
    var = tf.Variable(42, name='var')
    sess.run(tf.global_variables_initializer())
    tf.train.export_meta_graph('file.meta')
tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('file.meta')
    var = tf.global_variables()[0]
    sess.run(tf.initialize_all_variables())
    print sess.run(var)
Run Code Online (Sandbox Code Playgroud)

您的中间代码不起作用的原因是tf.get_variable()正在创建一个随机初始化的新变量。确保先做tf.get_variable_scope().reuse_variables()。看一下理解tf.get_variable()

不幸的是,使用创建的变量tf.Variable()不能tf.get_variable()直接与重用。查看此注释和此注释以确切了解原因。因此,如果您希望将来重用该变量,则需要使用它tf.get_variable()创建变量。