从Python API而不是saved_model_cli中提取(或设置)输入/输出TF张量名称信息

Art*_*yom 5 python inference tensorflow tf.keras

我用 Keras/TF2.5 训练了一个简单的模型并将其保存为保存模型。

tf.saved_model.save(my_model,'/path/to/model')
Run Code Online (Sandbox Code Playgroud)

如果我通过检查它

saved_model_cli show --dir /path/to/model --tag_set serve --signature_def serving_default
Run Code Online (Sandbox Code Playgroud)

我得到这些输出/名称:

inputs['conv2d_input'] tensor_info:
  dtype: DT_FLOAT
  shape: (-1, 32, 32, 1)
  name: serving_default_conv2d_input:0
outputs['dense'] tensor_info:
  dtype: DT_FLOAT
  shape: (-1, 2)
  name: StatefulPartitionedCall:0
Run Code Online (Sandbox Code Playgroud)

名称serving_default_conv2d_inputStatefulPartitionedCall实际上可以用于推断。

我想使用 python API 提取它们。如果我通过加载模型来查询它:

>>> m=tf.saved_model.load('/path/to/model')
>>> m.signatures['serving_default'].inputs[0].name
'conv2d_input:0'
>>> m.signatures['serving_default'].outputs[0].name
'Identity:0'
Run Code Online (Sandbox Code Playgroud)

我得到完全不同的名字。

问题:

  1. 如何从 python API 中serving_default_conv2d_input提取这些名称?StatefulPartitionedCall
  2. 或者,当我打电话时如何定义/修复名称tf.saved_model.save
  3. 这是什么:0意思?

还有附带问题:

如何通过 SavedModel 将 TF 模型部署到生产环境?

rvi*_*nas 3

显示的输入/输出张量名称saved_model_cli可以提取如下:

\n
from tensorflow.python.tools import saved_model_utils\n\nsaved_model_dir = \'/path/to/model\'\ntag_set = \'serve\'\nsignature_def_key = \'serving_default\'\n\n# 1. Load MetaGraphDef with saved_model_utils\nmeta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set)\n\n# 2. Get input signature names\ninput_signatures = list(meta_graph_def.signature_def[signature_def_key].inputs.values())\ninput_names = [signature.name for signature in input_signatures]\nprint(input_names)  # [\'serving_default_conv2d_input:0\']\n\n# 3. Get output signature names\noutput_signatures = list(meta_graph_def.signature_def[signature_def_key].outputs.values())\noutput_names = [signature.name for signature in output_signatures]\nprint(output_names)  # [\'StatefulPartitionedCall:0\']\n
Run Code Online (Sandbox Code Playgroud)\n

关于 的含义:0op_name:0表示“称为 的运算的第 0 个输出的张量op_name”。因此,您可能会使用\xe2\x80\xa6:1多个输出来获取操作的输出,但许多操作是单输出,因此您将始终使用\xe2\x80\xa6:0它们(来源:@mrry's comment)。

\n