如何从PredictResponse对象中检索float_val?

Dyl*_*dle 9 deep-learning tensorflow tensorflow-serving tensorflow-gpu

我正在对tensorflow服务模型进行预测,我将此PredictResponse对象作为输出返回:

结果:

outputs {
  key: "outputs"
  value {
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: 1
      }
      dim {
        size: 20
      }
    }
    float_val: 0.000343723397236
    float_val: 0.999655127525
    float_val: 3.96821117632e-11
    float_val: 1.20521548297e-09
    float_val: 2.09611101809e-08
    float_val: 1.46216549979e-09
    float_val: 3.87274603497e-08
    float_val: 1.83520256769e-08
    float_val: 1.47733780764e-08
    float_val: 8.00914179422e-08
    float_val: 2.29388191997e-07
    float_val: 6.27798826258e-08
    float_val: 1.08802950649e-07
    float_val: 4.39628813353e-08
    float_val: 7.87182985462e-10
    float_val: 1.31638898893e-07
    float_val: 1.42612295306e-08
    float_val: 3.0768305237e-07
    float_val: 1.12661648899e-08
    float_val: 1.68554503688e-08
  }
}
Run Code Online (Sandbox Code Playgroud)

我想把浮动值作为列表.或者,或者,返回argmax float_val的值/索引!

这是由以下生成的:

stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)result = stub.Predict(request,200.0)

感谢您的帮助.

Dyl*_*dle 9

答案是:

floats = result.outputs['outputs'].float_val
Run Code Online (Sandbox Code Playgroud)


Min*_*ark 8

您通常希望恢复具有形状的张量(不仅仅是一长串浮点数)。就是这样:

outputs_tensor_proto = result.outputs["outputs"]
shape = tf.TensorShape(outputs_tensor_proto.tensor_shape)
outputs = tf.constant(outputs_tensor_proto.float_val, shape=shape)
Run Code Online (Sandbox Code Playgroud)

如果您更喜欢获取 NumPy 数组,则只需替换最后一行:

outputs = np.array(outputs_tensor_proto.float_val).reshape(shape.as_list())
Run Code Online (Sandbox Code Playgroud)

如果您根本不想依赖 TensorFlow 库,出于某种原因:

outputs_tensor_proto = result.outputs["outputs"]
shape = [dim.size for dim in outputs_tensor_proto.tensor_shape.dim]
outputs = np.array(outputs_tensor_proto.float_val).reshape(shape)
Run Code Online (Sandbox Code Playgroud)

  • 这应该是公认的答案,因为它解决了一般情况。 (3认同)

Ale*_*sos 0

result["outputs"].float_val应该是一个Python列表