我想训练我的Tensorflow模型,冻结快照,然后使用新的输入数据以前馈模式(无需进一步培训)运行它.问题:
tf.train.export_meta_graph和tf.train.import_meta_graph正确的工具吗?collection_list在快照中包含我想要包含的所有变量的名称?(对我来说最简单的就是包含所有内容.)collection_list指定,则将导出模型中的所有集合." 这是否意味着如果我没有指定变量,collection_list那么模型中的所有变量都会被导出,因为它们位于默认集合中?to_proto()和from_proto()绝只添加到我已定义并希望导出的类中?如果我只使用标准的Python数据类型(int,float,list,dict)那么这是无关紧要的吗?提前致谢.
有点晚了,但我仍然会尝试回答。
- 是
tf.train.export_meta_graph和tf.train.import_meta_graph这个正确的工具?
我会这样说。请注意,tf.train.export_meta_graph当您通过保存模型时会隐式调用tf.train.Saver。要点是:
# create the model
...
saver = tf.train.Saver()
with tf.Session() as sess:
...
# save graph and variables
# if you are using global_step, the saver will automatically keep the n=5 latest checkpoints
saver.save(sess, save_path, global_step)
Run Code Online (Sandbox Code Playgroud)
然后还原:
save_path = ...
latest_checkpoint = tf.train.latest_checkpoint(save_path)
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
Run Code Online (Sandbox Code Playgroud)
请注意,除了调用之外,tf.train.import_meta_graph您还可以首先调用用于创建模型的原始代码。但是,我认为使用它会更优雅,import_meta_graph因为即使您无权访问创建模型的代码,也可以还原模型。
- 我是否需要
collection_list在快照中包含我要包含的所有变量的名称?(对我来说,最简单的方法就是包括所有内容。)
否。但是问题有点令人困惑:collection_listin export_meta_graph并不意味着是变量列表,而是集合(即字符串键列表)。
集合非常方便,例如,所有可训练的变量都会自动包含在集合中tf.GraphKeys.TRAINABLE_VARIABLES,您可以通过调用以下代码获取:
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
Run Code Online (Sandbox Code Playgroud)
要么
tf.trainable_variables() # defaults to the default graph
Run Code Online (Sandbox Code Playgroud)
如果在还原后您需要访问除可训练变量以外的其他中间结果,我发现将它们放入自定义集合中非常方便,如下所示:
...
input_ = tf.placeholder(tf.float32, shape=[64, 64])
....
tf.add_to_collection('my_custom_collection', input_)
Run Code Online (Sandbox Code Playgroud)
该集合会自动存储(除非您通过在的collection_list参数中省略此集合的名称来明确指定不这样做export_meta_graph)。因此,您可以input_在还原后简单地检索占位符,如下所示:
...
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
input_ = tf.get_collection_ref('my_custom_collection')[0]
Run Code Online (Sandbox Code Playgroud)
- Tensorflow文档说:“ 如果未
collection_list指定,则将导出模型中的所有集合。 ”这是否意味着如果我未指定任何变量,collection_list则将导出模型中的所有变量,因为它们位于默认集合中?
是。再次注意微妙的细节,即collection_list是集合列表而不是变量。实际上,如果只希望保存某些变量,则可以在构造tf.train.Saver对象时指定这些变量。从文档tf.train.Saver.__init__:
"""Creates a `Saver`.
The constructor adds ops to save and restore variables.
`var_list` specifies the variables that will be saved and restored. It can
be passed as a `dict` or a list:
* A `dict` of names to variables: The keys are the names that will be
used to save or restore the variables in the checkpoint files.
* A list of variables: The variables will be keyed with their op name in
the checkpoint files.
Run Code Online (Sandbox Code Playgroud)
- Tensorflow的文档说:“ 为了使Python对象可以从MetaGraphDef序列化到Python或从MetaGraphDef序列化,Python类必须实现
to_proto()和from_proto()方法,并使用register_proto_function在系统中注册它们。 ”这意味着to_proto()并且from_proto()必须仅将其添加到我已经定义并想要导出?如果我仅使用标准Python数据类型(int,float,list,dict),那么这无关紧要吗?
我从未使用过此功能,但我会说您的解释是正确的。
| 归档时间: |
|
| 查看次数: |
1636 次 |
| 最近记录: |