szi*_*szi 6 python java machine-learning protocol-buffers tensorflow
我一直在尝试导入并使用Java中训练有素的模型(Tensorflow,Python).
我能够在Python中保存模型,但是当我尝试使用Java中的相同模型进行预测时遇到了问题.
在这里,您可以看到用于初始化,训练,保存模型的python代码.
在这里,您可以看到用于导入和预测输入值的Java代码.
我得到的错误信息是:
Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
[[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)
Run Code Online (Sandbox Code Playgroud)
我相信,问题出在python代码中,但我无法找到它.
Java importGraphDef()函数只导入计算图(由tf.train.write_graphPython代码编写),它不加载训练变量的值(存储在检查点中),这就是为什么你会抱怨未初始化变量的错误.
另一方面,TensorFlow SavedModel格式包括有关模型的所有信息(图形,检查点状态,其他元数据),以及在Java中使用,您希望用它SavedModelBundle.load来创建使用训练变量值初始化的会话.
要从Python导出这种格式的模型,您可能需要查看相关问题将重新训练开始的SavedModel部署到Google Cloud ml引擎
在您的情况下,这应该类似于Python中的以下内容:
def save_model(session, input_tensor, output_tensor):
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
)
b = saved_model_builder.SavedModelBuilder('/tmp/model')
b.add_meta_graph_and_variables(session,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
b.save()
Run Code Online (Sandbox Code Playgroud)
并调用via save_model(session, x, yhat)
然后在Java中加载模型使用:
try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
// b.session().run(...)
}
Run Code Online (Sandbox Code Playgroud)
希望有所帮助.
| 归档时间: |
|
| 查看次数: |
7810 次 |
| 最近记录: |