ale*_*ger 3 java keras tensorflow tensorflow-serving
在 python 中,您可以简单地传递一个 numpy 数组以predict()从您的模型中获取预测。使用带有 的 Java 的等价物是SavedModelBundle什么?
model = tf.keras.models.Sequential([
# layers go here
])
model.compile(...)
model.fit(x_train, y_train)
predictions = model.predict(x_test_maxabs) # <= This line
Run Code Online (Sandbox Code Playgroud)
SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as in input? Tensor?
Run Code Online (Sandbox Code Playgroud)
小智 5
TensorFlow Python 会自动将您的 NumPy 数组转换为tf.Tensor. 在 TensorFlow Java 中,您可以直接操作张量。
现在SavedModelBundle没有predict方法。您需要获取会话并运行它,使用SessionRunner并为其提供输入张量。
例如,基于下一代 TF Java ( https://github.com/tensorflow/java ),您的代码最终看起来像这样(请注意,我在这里做了很多假设,x_test_maxabs因为您的代码示例没有清楚地解释它的来源):
try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
Tensor<TFloat32> output = model.session()
.runner()
.feed("input_name", input)
.fetch("output_name")
.run()
.expect(TFloat32.class)) {
float prediction = output.data().getFloat();
System.out.println("prediction = " + prediction);
}
}
Run Code Online (Sandbox Code Playgroud)
如果您不确定图中输入/输出张量的名称是什么,您可以通过查看签名定义以编程方式获取:
model.metaGraphDef().getSignatureDefMap().get("serving_default")
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1519 次 |
| 最近记录: |