Tensorflow Metagraph Fundamentals

Ron*_*hen 7 tensorflow

我想训练我的Tensorflow模型,冻结快照,然后使用新的输入数据以前馈模式(无需进一步培训)运行它.问题:

  1. tf.train.export_meta_graphtf.train.import_meta_graph正确的工具吗?
  2. 我是否需要collection_list在快照中包含我想要包含的所有变量的名称?(对我来说最简单的就是包含所有内容.)
  3. Tensorflow文档说:" 如果未collection_list指定,则将导出模型中的所有集合." 这是否意味着如果我没有指定变量,collection_list那么模型中的所有变量都会被导出,因为它们位于默认集合中?
  4. 该Tensorflow文件说:" 为了一个Python对象序列化和从MetaGraphDef,Python的类必须实现to_proto()和from_proto()方法,并使用register_proto_function系统注册. "这是否意味着,to_proto()from_proto()绝只添加到我已定义并希望导出的类中?如果我只使用标准的Python数据类型(int,float,list,dict)那么这是无关紧要的吗?

提前致谢.

kaf*_*man 7

有点晚了,但我仍然会尝试回答。

  1. tf.train.export_meta_graphtf.train.import_meta_graph这个正确的工具?

我会这样说。请注意,tf.train.export_meta_graph当您通过保存模型时会隐式调用tf.train.Saver。要点是:

# create the model
...
saver = tf.train.Saver()
with tf.Session() as sess:
    ...
    # save graph and variables
    # if you are using global_step, the saver will automatically keep the n=5 latest checkpoints
    saver.save(sess, save_path, global_step)
Run Code Online (Sandbox Code Playgroud)

然后还原:

save_path = ...
latest_checkpoint = tf.train.latest_checkpoint(save_path)
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)
Run Code Online (Sandbox Code Playgroud)

请注意,除了调用之外,tf.train.import_meta_graph您还可以首先调用用于创建模型的原始代码。但是,我认为使用它会更优雅,import_meta_graph因为即使您无权访问创建模型的代码,也可以还原模型。


  1. 我是否需要collection_list在快照中包含我要包含的所有变量的名称?(对我来说,最简单的方法就是包括所有内容。)

否。但是问题有点令人困惑:collection_listin export_meta_graph并不意味着是变量列表,而是集合(即字符串键列表)。

集合非常方便,例如,所有可训练的变量都会自动包含在集合中tf.GraphKeys.TRAINABLE_VARIABLES,您可以通过调用以下代码获取:

tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
Run Code Online (Sandbox Code Playgroud)

要么

tf.trainable_variables()  # defaults to the default graph
Run Code Online (Sandbox Code Playgroud)

如果在还原后您需要访问除可训练变量以外的其他中间结果,我发现将它们放入自定义集合中非常方便,如下所示:

...
input_ = tf.placeholder(tf.float32, shape=[64, 64])
....
tf.add_to_collection('my_custom_collection', input_)
Run Code Online (Sandbox Code Playgroud)

该集合会自动存储(除非您通过在的collection_list参数中省略此集合的名称来明确指定不这样做export_meta_graph)。因此,您可以input_在还原后简单地检索占位符,如下所示:

...
with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)
    input_ = tf.get_collection_ref('my_custom_collection')[0]
Run Code Online (Sandbox Code Playgroud)
  1. Tensorflow文档说:“ 如果未collection_list指定,则将导出模型中的所有集合。 ”这是否意味着如果我未指定任何变量,collection_list则将导出模型中的所有变量,因为它们位于默认集合中?

是。再次注意微妙的细节,即collection_list是集合列表而不是变量。实际上,如果只希望保存某些变量,则可以在构造tf.train.Saver对象时指定这些变量。从文档tf.train.Saver.__init__

 """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.
Run Code Online (Sandbox Code Playgroud)
  1. Tensorflow的文档说:“ 为了使Python对象可以从MetaGraphDef序列化到Python或从MetaGraphDef序列化,Python类必须实现 to_proto()from_proto()方法,并使用register_proto_function在系统中注册它们。 ”这意味着to_proto()并且 from_proto()必须仅将其添加到我已经定义并想要导出?如果我仅使用标准Python数据类型(int,float,list,dict),那么这无关紧要吗?

我从未使用过此功能,但我会说您的解释是正确的。

  • 他们应将此答案替换为Tensorflow保存和恢复页面。也。“但是,我认为使用import_meta_graph更为优雅,因为即使您无权访问创建模型的代码,也可以还原模型。” 终于有些道理!我不知道为什么与此相反的主流观点是相反的... (2认同)