Ham*_*d K 5 python java keras tensorflow
我有一个经过微调的 vgg 模型,我使用tensorflow.keras功能 API 创建了该模型,并使用tf.contrib.saved_model.save_keras_model保存了模型。因此模型以以下结构保存:assets文件夹,其中包含saved_model.json文件,saved_model.pb文件,以及variables文件夹,其中包含checkpoint,variables.data-00000-of-00001和variables.index。
我可以轻松地在 python 中加载模型并使用tf.contrib.saved_model.load_keras_model(saved_model_path)获得预测,但我不知道如何在 JAVA 中加载模型。我用谷歌搜索了很多,发现了这个 How to export Keras .h5 to tensorflow .pb? 导出为 pb 文件,然后按照此链接Loading in Java加载它。我无法冻结图表,而且我尝试使用 simple_save 但 tensorflow.keras 不支持 simple_save (AttributeError: module 'tensorflow.contrib.saved_model' has no attribute 'simple_save')。那么有人可以帮助我弄清楚在 JAVA 中加载我的模型(tensorflow.keras 功能 API 模型)需要哪些步骤。
我拥有的saved_model.pb 文件是否足够好,可以在JAVA 端加载?我需要创建输入/输出占位符吗?那我该如何导出呢?
我感谢您的帮助。
如果您有一个以 SavedModel 格式保存的模型(看起来您确实这样做了,并且类似的东西tf.contrib.saved_model.save_keras_model可以帮助创建),那么您可以在 Java 中用来SavedModelBundle.load加载和提供它。您不需要“冻结”模型。
您可能会发现以下有用:
但基本思想是您的代码将类似于:
try (SavedModelBundle model = SavedModelBundle.load("<directory>", "serve")) {
try (Tensor<?> input = makeInputTensor();
Tensor<?> output = model.session().runner().feed("INPUT_TENSOR", input).fetch("OUTPUT_TENSOR", output).run().get(0)) {
// Use output
}
}
Run Code Online (Sandbox Code Playgroud)
其中"INPUT_TENSOR"和"OUTPUT_TENSOR"是 TensorFlow 图中输入和输出节点的名称。saved_model_cli安装 TensorFlow for Python 时安装的命令行工具可以显示模型中这些张量的名称。
请注意,正如另一位评论者建议的那样,使用 TensorFlow Java API 可能比使用 TensorFlow Lite 更适合服务器/桌面应用程序。这是因为 TensorFLow Lite 运行时虽然针对小型设备进行了优化(在内存占用等方面),但尚无法导出所有模型。虽然 TensorFlow Java API 使用完全相同的运行时,因此具有与 TensorFlow for Python 完全相同的功能。
希望有帮助。