使用经过训练的对象检测 API 模型和 TF 2 进行批量预测

Alb*_*rto 7 object detection prediction batch-processing tensorflow

我在 TPU 上使用 TF 2 的对象检测 API 成功训练了一个模型,该模型保存为 .pb(SavedModel 格式)。然后我使用它重新加载它tf.saved_model.load,它在使用转换为 shape 的张量的单个图像预测框时工作正常(1, w, h, 3)

import tensorflow as tf
import numpy as np

# Load Object Detection APIs model
detect_fn = tf.saved_model.load('/path/to/saved_model/')

image = tf.io.read_file(image_path)
image_np = tf.image.decode_jpeg(image, channels=3).numpy()
input_tensor = np.expand_dims(image_np, 0)
detections = detect_fn(input_tensor) # This works fine
Run Code Online (Sandbox Code Playgroud)

问题是我需要进行批量预测以将其缩放到 50 万张图像,但该模型的输入签名似乎仅限于处理具有 shape 的数据(1, w, h, 3)。这也意味着我不能在 Tensorflow Serving 中使用批处理。我怎么解决这个问题?我可以只更改模型签名来处理批量数据吗?

所有工作(加载模型 + 预测)都在随对象检测 API 一起发布的官方容器内执行(来自此处

小智 5

我最近遇到了这个问题。当您使用exporter_main_v2.py将检查点文件转换为.pb文件时,它会调用exporter_lib_v2.py. 我认为在文件exporter_lib_v2.py这里)中,TF2 硬固定输入签名为 shape [1, None, None, 3]。我们必须将其更改为[None, None, None, 3]

需要修改这些线在该文件(138162170185从)1None。然后重建TF2 Object Detector API Repo(链接)并使用新构建的版本.pb再次导出。