Chi*_*tan 5 python opencv tensorflow attention-model
我想用OpenCV-DNN包装注意力 OCR模型以增加推理时间。我正在使用官方 TF models repo中的 TF 代码。
对于用 OpenCV-DNN 包装 TF 模型,我参考了这段代码。需要cv2.dnn.readNetFromTensorflow()“冻结图”和“图结构”来读取 TF 模型。
我使用此代码片段从元检查点文件导入结构并将图形结构保存在文件中.pbtxt。
# load graph from meta file
tf.reset_default_graph()
imported_meta = tf.train.import_meta_graph("attention_ocr_2017_08_09/model_demo_inference.ckpt.meta")
# restore graph structure, variables in session's graph
sess = tf.Session()
imported_meta.restore(sess, 'attention_ocr_2017_08_09/model_demo_inference.ckpt')
# write graph structure to a pbtxt file
tf.train.write_graph(sess.graph_def, './', 'train_attention.pbtxt', as_text=True)
Run Code Online (Sandbox Code Playgroud)
冻结图形,代码如下:
from tensorflow.python.tools import freeze_graph
freeze_graph.freeze_graph('train_attention.pbtxt', '', False, \
'attention_ocr_2017_08_09/model_demo_inference.ckpt', \
'AttentionOcr_v1_1/Softmax', \
'save/restore_all', 'save/Const:0', 'frozen_model.pb', True, "")
Run Code Online (Sandbox Code Playgroud)
最终代码使用函数中的pbtxt和文件。pbcv2.dnn.readNetFromTensorflow()
# Wrap TF model in OpenCV DNN
import cv2
FROZEN_GRAPH = "frozen_model.pb"
PB_TXT = "train_attention.pbtxt"
img = cv2.imread('testdata/fsns_train_00.png')
blob = cv2.dnn.blobFromImage(img,1)
net = cv2.dnn.readNetFromTensorflow(FROZEN_GRAPH, PB_TXT)
out = net.forward()
out
Run Code Online (Sandbox Code Playgroud)
遇到的错误是:
---------------------------------------------------------------------------
error Traceback (most recent call last)
<ipython-input-128-09e46e8b88ed> in <module>
9 blob = cv2.dnn.blobFromImage(img,1)
10
---> 11 net = cv2.dnn.readNetFromTensorflow(FROZEN_GRAPH, PB_TXT)
12 out = net.forward()
13 out
error: OpenCV(4.0.0) /Users/travis/build/skvark/opencv-python/opencv/modules/dnn/src/
tensorflow/tf_io.cpp:54: error: (-2:Unspecified error)
FAILED: ReadProtoFromTextFile(param_file, param).
Failed to parse GraphDef file: train_attention.pbtxt in function 'ReadTFNetParamsFromTextFileOrDie'
Run Code Online (Sandbox Code Playgroud)
注意:输出节点名称是通过查看使用以下命令生成的图中的张量列表来手动设置的:
# get names of all tensors
def get_names(graph=sess.graph):
return [t.name for op in graph.get_operations() for t in op.values()]
l1 = get_names()
for ele in l1:
print(ele)
Run Code Online (Sandbox Code Playgroud)
我非常感谢 SO 社区提供的任何帮助。
就我而言,我试图.pbtxt通过我的 Google Colaboratory 访问存储在 github 存储库中的文件。我只需要该文件,因此我尝试使用命令访问它,而不是克隆整个存储库!wget。我做了:
https://github.com/<username>/somethingelse/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt。该文件似乎已下载,但是当我运行时cv2.dnn.readNetFromTensorflow,它向我抛出一个错误:
OpenCV 错误:ReadTFNetParamsFromTextFileOrDie 中的未指定错误(失败:ReadProtoFromTextFileTF(param_file, param)。无法解析 GraphDef 文件:ssd_mobilenet_v1_coco_11_06_2017/graph.pbtxt),文件 opencv-3.3.1/modules/dnn/src/tensorflow/tf_io.cpp,行72 opencv-3.3.1/modules/dnn/src/tensorflow/tf_io.cpp:72:错误:(-2)失败:ReadProtoFromTextFileTF(param_file,param)。无法解析 GraphDef 文件:ReadTFNetParamsFromTextFileOrDie 函数中的 ssd_mobilenet_v1_coco_11_06_2017/graph.pbtxt
我意识到我应该使用rawgithub中的文件来下载,如下:
!wget https://raw.githubusercontent.com/---somethingelse as part of link---/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt
Run Code Online (Sandbox Code Playgroud)
由于我没有使用 raw,因此文件以 HTML 格式下载,而预期的文件是.pbtxt.
进入github仓库中的文件位置->点击右上角的raw选项->获取这个raw页面的URL->使用
!wget https://raw.githubusercontent.com/ ---somethingelse as part of link---/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
4330 次 |
| 最近记录: |