TensorFlow自定义估算器在训练后调用评估时卡住

Sum*_*ron 3 python tensorflow

我根据他们的指南在TensorFlow(v1.10)中制作了一个自定义估算器(参见本合作).

我训练玩具模型:

tf.estimator.train_and_evaluate(est, train_spec, eval_spec)
Run Code Online (Sandbox Code Playgroud)

然后,使用一些测试集数据,尝试使用以下方法评估模型:

test_fn = lambda: input_fn(DATASET['test'], run_params)
test_res = est.evaluate(input_fn=test_fn)
Run Code Online (Sandbox Code Playgroud)

(其中train_fnvalid_fn功能相同test_fn,例如足以tf.estimator.train_and_evaluate工作).

我希望会发生一些事情,但这是我得到的:

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-11-09-13:38:44
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./test/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Run Code Online (Sandbox Code Playgroud)

然后它就会永远运行.

怎么会?

Oli*_*ene 8

这是因为您无限期地重复数据集:

# In input_fn
dataset = dataset.repeat().batch(batch_size)
Run Code Online (Sandbox Code Playgroud)

默认情况下,estimator.evaluate()会一直运行,直到input_fn引发输入结束异常.因为您无限期地重复测试数据集,所以它永远不会引发异常并继续运行.

您可以在测试时删除重复,也可以使用原始'eval_spec'中使用的'steps'参数运行给定步数的评估.