如何从Tensorflow对象检测API正确提供对象检测模型?

Xin*_*ang 5 object-detection tensorflow tensorflow-serving

我正在将Tensorflow对象检测API(github.com/tensorflow/models/tree/master/object_detection)与一个对象检测任务一起使用。现在我在使用Tensorflow Serving(tensorflow.github.io/serving/)训练的检测模型的服务上遇到问题。

1.我遇到的第一个问题是关于将模型导出到可服务文件。对象检测api包括导出脚本,因此我能够将ckpt文件转换为带变量的pb文件。但是,输出文件在“变量”文件夹中将没有任何内容。我虽然这是一个错误,并在Github上进行了报告,但似乎他们进行了实习,将变量转换为常量,这样就没有变量了。细节可以在这里找到。

我在导出保存的模型时使用的标志如下:

    CUDA_VISIBLE_DEVICES=0 python export_inference_graph.py \
        --input_type image_tensor \
            --pipeline_config_path configs/rfcn_resnet50_car_Jul_20.config \
                --checkpoint_path resnet_ckpt/model.ckpt-17586 \
                    --inference_graph_path serving_model/1 \
                      --export_as_saved_model True
Run Code Online (Sandbox Code Playgroud)

当我将--export_as_saved_model切换为False时,它在python中运行良好。

但是,我仍然在为模型服务方面遇到问题。

当我尝试跑步时:

~/serving$ bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=gan --model_base_path=<my_model_path>
Run Code Online (Sandbox Code Playgroud)

我有:

2017-07-27 16:11:53.222439: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:155] Restoring SavedModel bundle.
2017-07-27 16:11:53.222497: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:165] The specified SavedModel has no variables; no checkpoints were restored.
2017-07-27 16:11:53.222502: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running LegacyInitOp on SavedModel bundle.
2017-07-27 16:11:53.229463: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:284] Loading SavedModel: success. Took 281805 microseconds.
2017-07-27 16:11:53.229508: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: gan version: 1}
2017-07-27 16:11:53.244716: I tensorflow_serving/model_servers/main.cc:290] Running ModelServer at 0.0.0.0:9000 ...
Run Code Online (Sandbox Code Playgroud)

我认为该模型未正确加载,因为它显示“指定的SavedModel没有变量;没有恢复检查点”。

但是由于我们已将所有变量都转换为常量,因此似乎是合理的。我不确定在这里。

2。我无法使用客户端调用服务器并无法对样本图像进行检测。

客户端脚本已在下面列出:

from __future__ import print_function
from __future__ import absolute_import

# Communication to TensorFlow server via gRPC
from grpc.beta import implementations
import tensorflow as tf
import numpy as np
from PIL import Image
# TensorFlow serving stuff to send messages
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2


# Command line arguments
tf.app.flags.DEFINE_string('server', 'localhost:9000',
                       'PredictionService host:port')
tf.app.flags.DEFINE_string('image', '', 'path to image in JPEG format')
FLAGS = tf.app.flags.FLAGS


def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
    (im_height, im_width, 3)).astype(np.uint8)

def main(_):
    host, port = FLAGS.server.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # Send request
    request = predict_pb2.PredictRequest()
    image = Image.open(FLAGS.image)
    image_np = load_image_into_numpy_array(image)
    image_np_expanded = np.expand_dims(image_np, axis=0)
    # Call GAN model to make prediction on the image
    request.model_spec.name = 'gan'
    request.model_spec.signature_name = 'predict_images'
    request.inputs['inputs'].CopyFrom(
    tf.contrib.util.make_tensor_proto(image_np_expanded))

    result = stub.Predict(request, 60.0)  # 60 secs timeout
    print(result)


if __name__ == '__main__':
    tf.app.run()
Run Code Online (Sandbox Code Playgroud)

为了匹配request.model_spec.signature_name = 'predict_images',我修改了对象检测api(github.com/tensorflow/models/blob/master/object_detection/exporter.py)中的exporter.py脚本,该脚本从第289行开始:

          signature_def_map={
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              detection_signature,
      },
Run Code Online (Sandbox Code Playgroud)

至:

          signature_def_map={
          'predict_images': detection_signature,
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              detection_signature,
      },
Run Code Online (Sandbox Code Playgroud)

由于我不知道如何调用默认签名密钥。

当我运行以下命令时:

bazel-bin/tensorflow_serving/example/client --server=localhost:9000 --image=<my_image_file>
Run Code Online (Sandbox Code Playgroud)

我收到以下错误消息:

    Traceback (most recent call last):
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 54, in <module>
    tf.app.run()
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 49, in main
    result = stub.Predict(request, 60.0)  # 60 secs timeout
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 324, in __call__
    self._request_serializer, self._response_deserializer)
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 210, in _blocking_unary_unary
    raise _abortion_error(rpc_error_call)
grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.NOT_FOUND, details="FeedInputs: unable to find feed output ToFloat:0")
Run Code Online (Sandbox Code Playgroud)

不太确定这里发生了什么。

最初,尽管我发现我的客户端脚本不正确,但是我发现AbortionError来自github.com/tensorflow/tensorflow/blob/f488419cd6d9256b25ba25cbe736097dfeee79f9/tensorflow/core/graph/subgraph.cc。似乎在构建图形时出现此错误。所以可能是我遇到的第一个问题。

我是新手,所以我很困惑。我想一开始我可能是错的。有什么方法可以正确导出并服务于检测模型?任何建议都会有很大帮助!

小智 2

当前的导出器代码未正确填充签名字段。所以使用模型服务器提供服务是行不通的。对此表示歉意。更好地支持导出模型的新版本即将推出。它包括服务所需的一些重要修复和改进,尤其是在 Cloud ML Engine 上提供服务。如果您想尝试它的早期版本,请参阅github 问题。

对于“指定的 SavedModel 没有变量;没有恢复任何检查点。” 消息,由于您所说的确切原因,这是预期的,因为所有变量都转换为图中的常量。对于“FeedInputs:无法找到feed输出ToFloat:0”的错误,请确保在构建模型服务器时使用TF 1.2。