TensorFlow,保存模型后为什么有3个文件?

Goi*_*Way 101 tensorflow

阅读完文档后,我保存了一个模型TensorFlow,这是我的演示代码:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)
Run Code Online (Sandbox Code Playgroud)

但在那之后,我发现有3个文件

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
Run Code Online (Sandbox Code Playgroud)

我无法通过恢复model.ckpt文件来恢复模型,因为没有这样的文件.这是我的代码

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
Run Code Online (Sandbox Code Playgroud)

那么,为什么有3个文件?

T.K*_*tel 106

试试这个:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")
Run Code Online (Sandbox Code Playgroud)

TensorFlow保存方法保存三种文件,因为它将图形结构变量值分开存储.该.meta文件描述了已保存的图形结构,因此您需要在恢复检查点之前导入它(否则它不知道保存的检查点值对应的变量).

或者,你可以这样做:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")
Run Code Online (Sandbox Code Playgroud)

即使没有命名文件model.ckpt,在恢复时仍然会通过该名称引用保存的检查点.从saver.py源代码:

用户只需要与用户指定的前缀进行交互...而不是任何物理路径名.

  • @ ajfbiw.s .meta存储图结构,.data存储图中每个变量的值,.index标识checkpiont.所以在上面的例子中:import_meta_graph使用.meta,saver.restore使用.data和.index (22认同)
  • 有谁知道`00000`和'00001`数字是什么意思?在`variables.data - ????? - of - ?????`文件 (4认同)

Gua*_*Liu 51

  • 元文件:描述保存的图形结构,包括GraphDef,SaverDef等; 然后申请tf.train.import_meta_graph('/tmp/model.ckpt.meta'),将恢复SaverGraph.

  • index file:它是一个字符串字符串不可变表(tensorflow :: table :: Table).每个键都是张量的名称,其值是序列化的BundleEntryProto.每个BundleEntryProto描述一个张量的元数据:哪个"数据"文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等.

  • 数据文件:它是TensorBundle集合,保存所有变量的值.


小智 5

我正在从Word2Vec tensorflow教程中恢复经过训练的单词嵌入。

如果您创建了多个检查点:

例如,创建的文件如下所示

型号.ckpt-55695.data-00000-of-00001

型号.ckpt-55695.index

型号.ckpt-55695.meta

尝试这个

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')
Run Code Online (Sandbox Code Playgroud)

调用restore_session()时:

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")
Run Code Online (Sandbox Code Playgroud)