我训练了一个模型,然后通过冻结该模型创建一个.pb文件.所以,我的问题是如何从.pb文件中获取权重,或者我必须为获取权重做更多的处理
@mrry,请指导我.
Kri*_*ist 17
我们先从.pb文件中加载图形.
import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' #path to your .pb file
with tf.Session(config=config) as sess:
print("load graph")
with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
graph_nodes=[n for n in graph_def.node]
Run Code Online (Sandbox Code Playgroud)
现在,当你冻结的图形.pb文件的变量转换为Const输入,哪些是trainabe变量也将被存储为权重Const的.pb文件.graph_nodes包含图中的所有节点.但我们对所有Const类型节点感兴趣.
wts = [n for n in graph_nodes if n.op=='Const']
Run Code Online (Sandbox Code Playgroud)
每个元素wts都是NodeDef类型.它有几个属性,如name,op等.值可以提取如下 -
from tensorflow.python.framework import tensor_util
for n in wts:
print "Name of the node - %s" % n.name
print "Value - "
print tensor_util.MakeNdarray(n.attr['value'].tensor)
Run Code Online (Sandbox Code Playgroud)
希望这能解决您的疑虑.
| 归档时间: |
|
| 查看次数: |
9882 次 |
| 最近记录: |