TensorFlow - 导入元图并使用它的变量

roi*_*hik 6 python variables neural-network deep-learning tensorflow

我正在使用TensorFlow v0.12训练分类CNN,然后想要使用训练模型为新数据创建标签.

在训练脚本结束时,我添加了以下代码行:

saver = tf.train.Saver()
save_path = saver.save(sess,'/home/path/to/model/model.ckpt')
Run Code Online (Sandbox Code Playgroud)

培训完成后,文件夹中出现的文件为:1.checkpoint ; 2. model.ckpt.data-00000-of-00001 ; 3. model.ckpt.index ; 4. model.ckpt.meta

然后我尝试使用.meta文件恢复模型.在本教程之后,我将以下行添加到我的分类代码中:

saver=tf.train.import_meta_graph(savepath+'model.ckpt.meta') #line1
Run Code Online (Sandbox Code Playgroud)

然后:

saver.restore(sess, save_path=savepath+'model.ckpt') #line2
Run Code Online (Sandbox Code Playgroud)

在更改之前,我需要再次构建图形,然后写入(而不是line1):

saver = tf.train.Saver()
Run Code Online (Sandbox Code Playgroud)

但是,删除图形构建,并使用line1以便还原它,引发了错误.错误是我在我的代码中使用了图中的变量,并且python没有识别它:

predictions = sess.run(y_conv, feed_dict={x: patches,keep_prob: 1.0})
Run Code Online (Sandbox Code Playgroud)

python无法识别y_conv参数.有一种方法可以使用元图恢复变量吗?如果不是,如果我不能使用原始图形中的变量,那么这个恢复有什么帮助?

我知道这个问题不是那么清楚,但我很难用文字表达问题.对不起...

谢谢你的回答,感谢你的帮助!投资回报率.

Rob*_*cok 13

这是可能的,不要担心.假设您不想再触摸图形,请执行以下操作:

saver = tf.train.import_meta_graph('model/export/{}.meta'.format(model_name))
saver.restore(sess, 'model/export/{}'.format(model_name))
graph = tf.get_default_graph()       
y_conv = graph.get_operation_by_name('y_conv').outputs[0]
predictions = sess.run(y_conv, feed_dict={x: patches,keep_prob: 1.0})
Run Code Online (Sandbox Code Playgroud)

但是,在构建图形然后引用它们时,首选方法是将ops添加到集合中.因此,在定义图形时,您将添加以下行:

tf.add_to_collection("y_conv", y_conv)
Run Code Online (Sandbox Code Playgroud)

然后在导入元图并恢复它之后,您将调用:

y_conv = tf.get_collection("y_conv")[0]
Run Code Online (Sandbox Code Playgroud)

它实际上是在文档中解释的 - 您链接的确切页面 - 但也许您错过了它.

顺便说一句,不需要.ckpt扩展,它可能会产生一些混乱,因为这是保存模型的旧方法.