如何使用checkpoint的tf.estimator.Estimator进行预测?

Rob*_*oob 8 python machine-learning computer-vision python-3.x tensorflow

我刚训练CNN识别具有张量流的太阳黑子.我的模型与此基本相同.问题是我无法找到关于如何使用训练阶段生成的检查点进行预测的明确解释.

尝试使用标准还原方法:

saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
saver.restore(sess,'./model/model.ckpt')
Run Code Online (Sandbox Code Playgroud)

但后来我无法弄清楚如何运行它.
尝试使用tf.estimator.Estimator.predict()这样:

# Create the Estimator (should reload the last checkpoint but it doesn't)
sunspot_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="./model")

# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log, every_n_iter=50)

# predict with the model and print results
pred_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": pred_data},
    shuffle=False)
pred_results = sunspot_classifier.predict(input_fn=pred_input_fn)
print(pred_results)
Run Code Online (Sandbox Code Playgroud)

但它的作用是吐出来的<generator object Estimator.predict at 0x10dda6bf8>.如果我使用相同的代码,但tf.estimator.Estimator.evaluate()它的作用就像魅力(重新加载模型,执行评估并将其发送到TensorBoard).

我知道有很多类似的问题,但我真的找不到对我有用的方法.

小智 8

sunspot_classifier.predict(input_fn=pred_input_fn)返回生成器.pred_results生成器对象也是如此.要从中获取价值,您需要迭代它next(pred_results)

解决方案是 print(next(pred_results))