如何在Java中使用TensorFlow LinearClassifier

Jan*_*ers 1 java tensorflow

在Python中,我训练了TensorFlow LinearClassifier并将其保存为:

model = tf.contrib.learn.LinearClassifier(feature_columns=columns)
model.fit(input_fn=train_input_fn, steps=100)
model.export_savedmodel(export_dir, parsing_serving_input_fn)
Run Code Online (Sandbox Code Playgroud)

通过使用TensorFlow Java API,我可以使用以下方法在Java中加载此模型:

model = SavedModelBundle.load(export_dir, "serve");
Run Code Online (Sandbox Code Playgroud)

看来我应该能够使用类似

model.session().runner().feed(???, ???).fetch(???, ???).run()
Run Code Online (Sandbox Code Playgroud)

但是我应该从图形中获取/获取哪些变量名称/数据以提供其功能并获取类的概率?据我所知,Java文档缺少此信息。

ash*_*ash 6

要馈送的节点的名称取决于parsing_serving_input_fn执行的操作,尤其是它们应该是Tensor所返回的对象的名称parsing_serving_input_fn。要获取的节点的名称将取决于您所预测的内容(关于model.predict()是否使用来自Python的模型的参数)。

也就是说,TensorFlow保存的模型格式的确包含了该模型的“签名”(即,可以提供或获取的所有Tensor的名称),作为可以提供提示的元数据。

在Python中,您可以使用以下方法加载保存的模型并列出其签名:

with tf.Session() as sess:
  md = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  print(sig)
Run Code Online (Sandbox Code Playgroud)

它将打印如下内容:

inputs {
  key: "inputs"
  value {
    name: "input_example_tensor:0"
    dtype: DT_STRING
    tensor_shape {
      dim {
        size: -1
      }
    }
  }
}
outputs {
  key: "scores"
  value {
    name: "linear/binary_logistic_head/predictions/probabilities:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: -1
      }
      dim {
        size: 2
      }
    }
  }
}
method_name: "tensorflow/serving/classify"
Run Code Online (Sandbox Code Playgroud)

建议您使用Java进行以下操作:

Tensor t = /* Tensor object to be fed */
model.session().runner().feed("input_example_tensor", t).fetch("linear/binary_logistic_head/predictions/probabilities").run()
Run Code Online (Sandbox Code Playgroud)

如果您的程序包含使用TensorFlow协议缓冲区生成的Java代码(打包在org.tensorflow:proto工件中),您也可以纯粹在Java中提取此信息,例如:

// Same as tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
// in Python. Perhaps this should be an exported constant in TensorFlow's Java API.
final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default"; 

final SignatureDef sig =
      MetaGraphDef.parseFrom(model.metaGraphDef())
          .getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);
Run Code Online (Sandbox Code Playgroud)

您将必须添加:

import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
Run Code Online (Sandbox Code Playgroud)

由于Java API和save-model-format有点新,因此文档中还有很大的改进空间。

希望能有所帮助。