在Python中,我训练了TensorFlow LinearClassifier并将其保存为:
model = tf.contrib.learn.LinearClassifier(feature_columns=columns)
model.fit(input_fn=train_input_fn, steps=100)
model.export_savedmodel(export_dir, parsing_serving_input_fn)
Run Code Online (Sandbox Code Playgroud)
通过使用TensorFlow Java API,我可以使用以下方法在Java中加载此模型:
model = SavedModelBundle.load(export_dir, "serve");
Run Code Online (Sandbox Code Playgroud)
看来我应该能够使用类似
model.session().runner().feed(???, ???).fetch(???, ???).run()
Run Code Online (Sandbox Code Playgroud)
但是我应该从图形中获取/获取哪些变量名称/数据以提供其功能并获取类的概率?据我所知,Java文档缺少此信息。
要馈送的节点的名称取决于parsing_serving_input_fn执行的操作,尤其是它们应该是Tensor所返回的对象的名称parsing_serving_input_fn。要获取的节点的名称将取决于您所预测的内容(关于model.predict()是否使用来自Python的模型的参数)。
也就是说,TensorFlow保存的模型格式的确包含了该模型的“签名”(即,可以提供或获取的所有Tensor的名称),作为可以提供提示的元数据。
在Python中,您可以使用以下方法加载保存的模型并列出其签名:
with tf.Session() as sess:
md = tf.saved_model.loader.load(sess, ['serve'], export_dir)
sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
print(sig)
Run Code Online (Sandbox Code Playgroud)
它将打印如下内容:
inputs {
key: "inputs"
value {
name: "input_example_tensor:0"
dtype: DT_STRING
tensor_shape {
dim {
size: -1
}
}
}
}
outputs {
key: "scores"
value {
name: "linear/binary_logistic_head/predictions/probabilities:0"
dtype: DT_FLOAT
tensor_shape {
dim {
size: -1
}
dim {
size: 2
}
}
}
}
method_name: "tensorflow/serving/classify"
Run Code Online (Sandbox Code Playgroud)
建议您使用Java进行以下操作:
Tensor t = /* Tensor object to be fed */
model.session().runner().feed("input_example_tensor", t).fetch("linear/binary_logistic_head/predictions/probabilities").run()
Run Code Online (Sandbox Code Playgroud)
如果您的程序包含使用TensorFlow协议缓冲区生成的Java代码(打包在org.tensorflow:proto工件中),您也可以纯粹在Java中提取此信息,例如:
// Same as tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
// in Python. Perhaps this should be an exported constant in TensorFlow's Java API.
final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default";
final SignatureDef sig =
MetaGraphDef.parseFrom(model.metaGraphDef())
.getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);
Run Code Online (Sandbox Code Playgroud)
您将必须添加:
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
Run Code Online (Sandbox Code Playgroud)
由于Java API和save-model-format有点新,因此文档中还有很大的改进空间。
希望能有所帮助。
| 归档时间: |
|
| 查看次数: |
1178 次 |
| 最近记录: |