mul*_*Pro 5 python java keras tensorflow
我在 Python 中使用 Keras 训练了一个模型,并且我想在 Java 程序中使用该训练后的模型。我原本打算在Java中直接使用Keras模型,但似乎Keras 2.0还没有得到很好的支持。因此,我将 Keras 模型(存储在 .h5 中)转换为张量流模型(存储在 .pb 中)。现在我想在我的 Java 代码中使用这个模型。但是,我需要 3 个字符串才能成功完成此操作:
我几乎不知道如何找到这些字符串。此时我无法对模型进行太多修改,特别是因为 Tensorflow 2.0 已删除get_session(),这意味着我需要使用 Tensorflow 1.0,这在从 Keras 2.0 加载模型时不断出现错误。我能够列出我的模型的所有操作,但我不知道这近 100 个操作中哪一个是正确的。我也不知道 metagraphdef 的标签。
我如何找到这 3 条信息?
如果您使用 pip(或类似的东西,如 conda 等)安装 TensorFlow,它应该附带该saved_model_cli实用程序。
您可以使用它从导出的模型中获得一些见解:
saved_model_cli show --dir <model_dir> --tag_set <tag> --signature_def <signature>
Run Code Online (Sandbox Code Playgroud)
在指南中查找更多信息。
这是我的模型之一的结果:
The given SavedModel SignatureDef contains the following input(s):
inputs['float32_Input'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 118)
name: serving_default_float32_Input:0
inputs['uint8_Input'] tensor_info:
dtype: DT_UINT8
shape: (-1, 583)
name: serving_default_uint8_Input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['tf_op_layer_ExpandDims'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: PartitionedCall:0
Method name is: tensorflow/serving/predict
Run Code Online (Sandbox Code Playgroud)
float32_Input, uint8_Input,ExpandDims是我在 Python 中的层的名称。要在 Java 中使用它,我必须使用名称:serving_default_float32_Input、serving_default_float32_Input和PartitionedCall。
| 归档时间: |
|
| 查看次数: |
843 次 |
| 最近记录: |