How can I restore a model by filename in TensorFlow r0.12?

Tay*_*ers 12 tensorflow

I have run the distributed MNIST example: https://github.com/tensorflow/tensorflow/blob/r0.12/tensorflow/tools/dist_test/python/mnist_replica.py

I have configured the saver to keep every checkpoint (max_to_keep=0):

saver = tf.train.Saver(max_to_keep=0)

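In previous releases, such as r0.11, I was able to iterate over each checkpoint and evaluate the accuracy of the model. That gave me a plot of accuracy against global step (iteration).

Before r0.12, a TensorFlow checkpoint was saved as two files, model.ckpt-1234 and model.ckpt-1234.meta. The model could be restored by passing the model.ckpt-1234 filename: saver.restore(sess, 'model.ckpt-1234').

However, I noticed that r0.12 now writes three output files: model.ckpt-1234.data-00000-of-00001, model.ckpt-1234.index, and model.ckpt-1234.meta.

The restore documentation says that a path such as /train/path/model.ckpt should be given to restore, rather than a filename. Is there a way to load one checkpoint at a time in order to evaluate it? I tried passing the model.ckpt-1234.data-00000-of-00001, model.ckpt-1234.index, and model.ckpt-1234.meta files in turn, but got errors like the following: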

W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open logdir/2016-12-08-13-54/model.ckpt-0.data-00000-of-00001: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

NotFoundError (see above for traceback): Tensor name "hid_b" not found in checkpoint files logdir/2016-12-08-13-54/model.ckpt-0.index [[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]

W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open logdir/2016-12-08-13-54/model.ckpt-0.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

I am running on OS X Sierra, with TensorFlow r0.12 installed via pip.

Any guidance would be helpful.

Thanks.

Yua*_* Ma 8

I also use TensorFlow r0.12, and I have not seen any problem with saving and restoring models. Here is a simple piece of code you can try:

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.

  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model

Although r0.12 stores a checkpoint in multiple files, you can restore it by passing their common prefix (in your case, "model.ckpt").
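To evaluate checkpoints one at a time, the common prefix of each checkpoint can be recovered from the file names on disk. The following is a minimal sketch in plain Python, assuming the three-file layout described in the question; `to_prefix` and `list_prefixes` are hypothetical helper names, not part of TensorFlow.

```python
import os
import re

# Suffixes that appear after the checkpoint prefix in the file names
# shown in the question (.index, .meta, .data-00000-of-00001).
_SUFFIX_RE = re.compile(r"\.(index|meta|data-\d+-of-\d+)$")


def to_prefix(path):
    """Strip a trailing .index, .meta, or .data-N-of-M suffix, if present."""
    return _SUFFIX_RE.sub("", path)


def list_prefixes(logdir):
    """Return the checkpoint prefixes found in logdir, ordered by global step."""
    prefixes = {
        to_prefix(os.path.join(logdir, name))
        for name in os.listdir(logdir)
        if name.endswith(".index")
    }

    def step(prefix):
        # Prefixes look like ".../model.ckpt-1234"; ones without a step sort first.
        match = re.search(r"-(\d+)$", prefix)
        return int(match.group(1)) if match else -1

    return sorted(prefixes, key=step)
```

Each returned prefix can then be passed to saver.restore(sess, prefix) inside the evaluation loop, instead of any of the individual files.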


小智 5

r0.12 changed the checkpoint format. You can save the model in the old (V1) format instead:

import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2
...
saver = tf.train.Saver(write_version=saver_pb2.SaverDef.V1)
saver.save(sess, './model.ckpt', global_step=step)

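According to the TensorFlow v0.12.0 RC0 release notes:

New checkpoint format becomes the default in tf.train.Saver. Old V1 checkpoints continue to be readable; controlled by the write_version argument, tf.train.Saver now by default writes out in the new V2 format. It significantly reduces the peak memory required and latency incurred during restore.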

See my blog for details.
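When a log directory mixes checkpoints from different releases, it can help to check which format a given prefix was written in. This is a rough sketch that relies only on the file layouts described in this thread (V1: a file named exactly like the prefix; V2: a prefix.index file); `checkpoint_format` is a hypothetical helper, not a TensorFlow function.

```python
import os


def checkpoint_format(prefix):
    """Guess the Saver format of the checkpoint at `prefix` from files on disk."""
    if os.path.exists(prefix + ".index"):
        # V2 layout: prefix.index plus prefix.data-* shards.
        return "V2"
    if os.path.isfile(prefix):
        # V1 layout: a single data file named exactly like the prefix.
        return "V1"
    # Nothing matching either layout was found.
    return None
```

In both cases the value to hand to saver.restore is the prefix itself, never the .index, .data, or .meta file.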


Tay*_*ers 1

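OK, I can answer my own question. I found that my Python script was adding an extra "/" to the path, so I was actually calling: saver.restore(sess, '/path/to/train//model.ckpt-1234')

Somehow this caused a problem for TensorFlow.

Once I removed the extra slash and called: saver.restore(sess, '/path/to/train/model.ckpt-1234')

it worked as expected.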
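One way to avoid this kind of accidental double slash is to build the prefix with os.path instead of string concatenation. A small sketch, assuming a POSIX-style path as on OS X; `checkpoint_prefix` is a hypothetical helper name.

```python
import os


def checkpoint_prefix(train_dir, name):
    """Join directory and checkpoint name, collapsing any duplicate slashes."""
    return os.path.normpath(os.path.join(train_dir, name))


# A trailing slash on the directory no longer produces "train//model.ckpt-1234".
prefix = checkpoint_prefix("/path/to/train/", "model.ckpt-1234")
```

The resulting prefix can be passed directly to saver.restore(sess, prefix).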