小编Bra*_*vis的帖子

如何从对象检测中加载已保存的模型以进行推理?

我是 Tensorflow 的新手,并且一直在使用 Tensorflow 对象检测 API 对 SSD 进行实验。我可以成功训练一个模型,但默认情况下,它只保存最后 n 个检查点。我想改为保存损失最低的最后 n 个检查点(我假设这是最好的指标)。

我找到了 tf.estimator.BestExporter,它导出了一个 saved_model.pb 和变量。但是,我还没有弄清楚如何加载保存的模型并对其进行推理。在 checkpoiont 上运行 models/research/object_detection/export_inference_graph.py 后,我可以轻松加载检查点并使用对象检测 jupyter notebook 对其运行推理:https : //github.com/tensorflow/models/blob/master/research /object_detection/object_detection_tutorial.ipynb

我找到了有关加载保存模型的文档,并且可以加载这样的图表:

with tf.Session(graph=tf.Graph()) as sess:
        tags = [tag_constants.SERVING]
        meta_graph = tf.saved_model.loader.load(sess, tags, PATH_TO_SAVED_MODEL)
        detection_graph = tf.get_default_graph()
Run Code Online (Sandbox Code Playgroud)

但是,当我将该图与上述 jupyter 笔记本一起使用时,出现错误:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-17-9e48f0d04df2> in <module>
      7   image_np_expanded = np.expand_dims(image_np, axis=0)
      8   # Actual detection.
----> 9   output_dict = run_inference_for_single_image(image_np, detection_graph)
     10   # Visualization of the results of a …
Run Code Online (Sandbox Code Playgroud)

tensorflow

4
推荐指数
1
解决办法
4123
查看次数

标签 统计

tensorflow ×1