如何从python中的.pb文件恢复Tensorflow模型?

viz*_*tiz 16 python android tensorflow

我有一个tensorflow .pb文件,我想加载到python DNN,恢复图形并获得预测.我这样做是为了测试创建的.pb文件是否可以使预测类似于普通的Saver.save()模型.

我的基本问题是,当我使用上面提到的.pb文件在Android上制作时,我得到了一个非常不同的预测值

我的.pb文件创建代码:

frozen_graph = tf.graph_util.convert_variables_to_constants(
        session,
        session.graph_def,
        ['outputLayer/Softmax']
    )
with open('frozen_model.pb', 'wb') as f:
  f.write(frozen_graph.SerializeToString())
Run Code Online (Sandbox Code Playgroud)

所以我有两个主要问题:

  1. 如何将上述.pb文件加载到python Tensorflow模型?
  2. 为什么我在python和android中获得完全不同的预测值?

sah*_*ahu 22

以下代码将读取模型并打印出图中节点的名称.

import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
   print("load graph")
   with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)
Run Code Online (Sandbox Code Playgroud)

你正在冻结图形,这就是为什么你得到不同的结果,基本上权重没有存储在你的模型中.您可以使用freeze_graph.py(链接)获取正确存储的图形.

  • 我上到`graph_def.ParseFromString(f.read())`DecodeError:错误解析消息 (8认同)

cay*_*lus 7

这是 tensorflow 2 的更新代码。

import tensorflow as tf

GRAPH_PB_PATH = './frozen_model.pb'
with tf.compat.v1.Session() as sess:
   print("load graph")
   with tf.io.gfile.GFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.compat.v1.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)
Run Code Online (Sandbox Code Playgroud)