如何使用tf.reset_default_graph()

Bos*_*sen 11 machine-learning tensorflow

每当我尝试使用时tf.reset_default_graph(),我都会收到此错误:IndexError: list index out of range或者``.我应该在哪部分代码中使用它?我什么时候应该使用它?

编辑:

我更新了代码,但错误仍然存​​在.

def evaluate():
    with tf.name_scope("loss"):
        global x # x is a tf.placeholder()
        xentropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=neural_network(x))
        loss = tf.reduce_mean(xentropy, name="loss")

    with tf.name_scope("train"):
        optimizer = tf.train.AdamOptimizer()
        training_op = optimizer.minimize(loss)

    with tf.name_scope("exec"):
        with tf.Session() as sess:
            for i in range(1, 2):
                sess.run(tf.global_variables_initializer())
                sess.run(training_op, feed_dict={x: np.array(train_data).reshape([-1, 1]), y: label})
                print "Training " + str(i)
                saver = tf.train.Saver()
                saver.save(sess, "saved_models/testing")
                print "Model Saved."


def predict():
    with tf.name_scope("predict"):
        tf.reset_default_graph()
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            output_ = tf.get_default_graph().get_tensor_by_name('output_layer:0')
            print sess.run(output_, feed_dict={x: np.array([12003]).reshape([-1, 1])})


def main():
    print "Starting Program..."
    evaluate()
    writer = tf.summary.FileWriter("mygraph/logs", tf.get_default_graph())
    predict()
Run Code Online (Sandbox Code Playgroud)

如果我从更新的代码中删除tf.reset_default_graph(),我收到此错误: ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used

根据我目前的理解,tf.reset_default_graph()会删除所有图形,因此我避免了上面提到的错误(ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used)

Sal*_*ali 16

最有可能的是你如何使用它:

import tensorflow as tf
a = tf.constant(1)
with tf.Session() as sess:
    tf.reset_default_graph()
Run Code Online (Sandbox Code Playgroud)

这是因为您在会话中使用它:

在tf.Session或tf.InteractiveSession处于活动状态时调用此函数将导致未定义的行为.在调用此函数后使用任何先前创建的tf.Operation或tf.Tensor对象将导致未定义的行为


当我在jupyter笔记本中进行实验时,它在测试阶段可能会有所帮助(至少对我而言).我从未在生产中使用它,也看不出它有什么用.这是我在笔记本中的例子.

import tensorflow as tf
# create some graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(...)
Run Code Online (Sandbox Code Playgroud)

现在我不再需要这些东西,但是如果我要创建另一个图形并尝试在tensorboard中可视化它,我会看到旧节点和新节点.我可以重新启动内核并只运行下一个单元格,但我可以这样做

tf.reset_default_graph()
# create a new graph
with tf.Session() as sess:
    print sess.run(...)
Run Code Online (Sandbox Code Playgroud)

在OP添加他的代码后编辑:

with tf.name_scope("predict"):
    tf.reset_default_graph()
Run Code Online (Sandbox Code Playgroud)

这是大概发生的事情.您的代码失败,因为tf.reset_default_graph()已经向图表添加了一些内容.虽然是这里面"添加一些东西到图形"你告诉TF完全删除图形,但不能因为它是忙于增加一些东西.

  • 所以基本上`tf.reset_default_graph()`只用于测试而不用于整天运行的应用程序? (2认同)