Java Tensorflow + Keras 相当于 model.predict()

ale*_*ger 3 java keras tensorflow tensorflow-serving

在 python 中,您可以简单地传递一个 numpy 数组以predict()从您的模型中获取预测。使用带有 的 Java 的等价物是SavedModelBundle什么?

Python

model = tf.keras.models.Sequential([
  # layers go here
])
model.compile(...)
model.fit(x_train, y_train)

predictions = model.predict(x_test_maxabs) # <= This line 
Run Code Online (Sandbox Code Playgroud)

爪哇

SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as in input? Tensor?
Run Code Online (Sandbox Code Playgroud)

小智 5

TensorFlow Python 会自动将您的 NumPy 数组转换为tf.Tensor. 在 TensorFlow Java 中,您可以直接操作张量。

现在SavedModelBundle没有predict方法。您需要获取会话并运行它,使用SessionRunner并为其提供输入张量。

例如,基于下一代 TF Java ( https://github.com/tensorflow/java ),您的代码最终看起来像这样(请注意,我在这里做了很多假设,x_test_maxabs因为您的代码示例没有清楚地解释它的来源):

try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
    try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
        Tensor<TFloat32> output = model.session()
            .runner()
            .feed("input_name", input)
            .fetch("output_name")
            .run()
            .expect(TFloat32.class)) {

        float prediction = output.data().getFloat();
        System.out.println("prediction = " + prediction);
    }        
}
Run Code Online (Sandbox Code Playgroud)

如果您不确定图中输入/输出张量的名称是什么,您可以通过查看签名定义以编程方式获取:

model.metaGraphDef().getSignatureDefMap().get("serving_default")
Run Code Online (Sandbox Code Playgroud)

  • 我必须升级到 TF 2.x 才能获得 `model.metaGraphDef().getSignatureDefMap().get("serving_default")` 但除此之外它还有效! (2认同)