tensorflow.keras 在 python 中保存模型并在 Java 中加载

Ham*_*d K 5 python java keras tensorflow

我有一个经过微调的 vgg 模型,我使用tensorflow.keras功能 API 创建了该模型,并使用tf.contrib.saved_model.save_keras_model保存了模型。因此模型以以下结构保存:assets文件夹,其中包含saved_model.json文件,saved_model.pb文件,以及variables文件夹,其中包含checkpointvariables.data-00000-of-00001variables.index

我可以轻松地在 python 中加载模型并使用tf.contrib.saved_model.load_keras_model(saved_model_path)获得预测,但我不知道如何在 JAVA 中加载模型。我用谷歌搜索了很多,发现了这个 How to export Keras .h5 to tensorflow .pb? 导出为 pb 文件,然后按照此链接Loading in Java加载它。我无法冻结图表,而且我尝试使用 simple_save 但 tensorflow.keras 不支持 simple_save (AttributeError: module 'tensorflow.contrib.saved_model' has no attribute 'simple_save')。那么有人可以帮助我弄清楚在 JAVA 中加载我的模型(tensorflow.keras 功能 API 模型)需要哪些步骤。

我拥有的saved_model.pb 文件是否足够好,可以在JAVA 端加载?我需要创建输入/输出占位符吗?那我该如何导出呢?
我感谢您的帮助。

ash*_*ash 2

如果您有一个以 SavedModel 格式保存的模型(看起来您确实这样做了,并且类似的东西tf.contrib.saved_model.save_keras_model可以帮助创建),那么您可以在 Java 中用来SavedModelBundle.load加载和提供它。您不需要“冻结”模型。

您可能会发现以下有用:

但基本思想是您的代码将类似于:

try (SavedModelBundle model = SavedModelBundle.load("<directory>", "serve")) {
  try (Tensor<?> input = makeInputTensor();
       Tensor<?> output = model.session().runner().feed("INPUT_TENSOR", input).fetch("OUTPUT_TENSOR", output).run().get(0)) {
  // Use output
  }
}
Run Code Online (Sandbox Code Playgroud)

其中"INPUT_TENSOR""OUTPUT_TENSOR"是 TensorFlow 图中输入和输出节点的名称。saved_model_cli安装 TensorFlow for Python 时安装的命令行工具可以显示模型中这些张量的名称。

请注意,正如另一位评论者建议的那样,使用 TensorFlow Java API 可能比使用 TensorFlow Lite 更适合服务器/桌面应用程序。这是因为 TensorFLow Lite 运行时虽然针对小型设备进行了优化(在内存占用等方面),但尚无法导出所有模型。虽然 TensorFlow Java API 使用完全相同的运行时,因此具有与 TensorFlow for Python 完全相同的功能。

希望有帮助。