tra*_*veh 6 python tensorflow google-cloud-ml
我不明白如何使用TensorFlow Estimator API进行单一预测 - 我的代码导致无限循环,不断预测相同的输入.
根据文档,当input_fn引发StopIteration异常时,预测应该停止:
input_fn:返回功能的输入函数,它是Tensor或SparseTensor的字符串功能名称字典.如果它返回一个元组,则第一个项目被提取为特征.预测将继续,直到input_fn引发输入结束异常(OutOfRangeError或StopIteration).
这是我的代码中的相关部分:
classifier = tf.estimator.Estimator(model_fn=image_classifier, model_dir=output_dir,
config=training_config, params=hparams)
def make_predict_input_fn(filename):
queue = [ filename ]
def _input_fn():
if len(queue) == 0:
raise StopIteration
image = model.read_and_preprocess(queue.pop())
return {'image': image}
return _input_fn
predictions = classifier.predict(make_predict_input_fn('garden-rose-red-pink-56866.jpeg'))
for i, p in enumerate(predictions):
print("Prediction %s: %s" % (i + 1, p["class"]))
Run Code Online (Sandbox Code Playgroud)
我错过了什么?
这是因为 input_fn() 需要是一个生成器。将您的函数更改为(yield 而不是 return):
def make_predict_input_fn(filename):
queue = [ filename ]
def _input_fn():
if len(queue) == 0:
raise StopIteration
image = model.read_and_preprocess(queue.pop())
yield {'image': image}
return _input_fn
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
487 次 |
| 最近记录: |