Does tf.estimator.Estimator.train() maintain input_fn state?

Mar*_*ard 3 tensorflow

I have been using my own Estimator/Experiment-like code for over a year, but I want to finally jump on the Dataset+Estimator bandwagon.

I would like to do something like the following:

for _ in range(N):
  estimator.train(train_input_fn, steps=1000)
  estimator.evaluate(validation_input_fn)

where train_input_fn creates a tf.data.Dataset that loops over the training set forever, and validation_input_fn creates a tf.data.Dataset that does one pass over the validation set.

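Does train() maintain the state of train_input_fn across calls (i.e. only call it once if the reference matches)? Is this how people are doing their training loops with Estimator?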

Mar*_*ard 5

As I mentioned in my comment above, it looks like estimator.train() does not preserve the state of input_fn across calls.
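
To make the consequence concrete, here is a plain-Python analogy rather than TensorFlow code. ToyEstimator and make_input_fn are hypothetical names invented for this sketch; they mimic only the call pattern in which every train() call invokes input_fn again and so starts from a freshly built iterator.

```python
import itertools


def make_input_fn(data):
    """Return an input_fn that builds a fresh, endlessly repeating iterator."""
    def input_fn():
        return itertools.cycle(data)
    return input_fn


class ToyEstimator:
    """Hypothetical stand-in; mimics only how train() consumes input_fn."""

    def __init__(self):
        self.global_step = 0
        self.seen = []

    def train(self, input_fn, steps):
        # The input pipeline is rebuilt on every call, so iteration restarts.
        iterator = input_fn()
        for _ in range(steps):
            self.seen.append(next(iterator))
            self.global_step += 1


toy = ToyEstimator()
train_input_fn = make_input_fn([0, 1, 2, 3, 4])
toy.train(train_input_fn, steps=3)
toy.train(train_input_fn, steps=3)
print(toy.seen)         # [0, 1, 2, 0, 1, 2]
print(toy.global_step)  # 6
```

In this toy the step counter carries over between calls while the data position does not, so with an unshuffled dataset and a small steps value the later elements would never be reached.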

The solution I am going with, and possibly the intended one, is to pass a checkpoint listener that runs the evaluation to estimator.train() through saving_listeners. For example:

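import tensorflow as tf


class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
  """Runs an evaluation each time a checkpoint is saved."""

  def __init__(self, estimator, input_fn):
    self.estimator = estimator
    self.input_fn = input_fn

  def after_save(self, session, global_step):
    # Called right after a checkpoint is written, so evaluate() can use it.
    self.estimator.evaluate(self.input_fn)

# A single train() call covers the whole run; evaluation happens in the listener.
estimator.train(
  input_fn=lambda: _train_input_fn(...),
  max_steps=N,
  saving_listeners=[
    EvalCheckpointSaverListener(
      estimator,
      lambda: _eval_input_fn(...),
    ),
  ],
)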
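
One thing worth noting about the listener approach is that evaluation runs whenever a checkpoint is saved, so its cadence follows the checkpoint schedule rather than a steps argument. The sketch below is a plain-Python toy, not TensorFlow code: RecordingListener and toy_train are hypothetical names, and the save schedule is a simplification of what the real checkpoint hook does.

```python
class RecordingListener:
    """Hypothetical listener; records a step where evaluate() would run."""

    def __init__(self):
        self.eval_steps = []

    def after_save(self, global_step):
        self.eval_steps.append(global_step)


def toy_train(max_steps, save_every, listeners):
    """Mimic a loop that checkpoints every `save_every` steps and at the end."""
    for step in range(1, max_steps + 1):
        if step % save_every == 0 or step == max_steps:
            # A checkpoint is "saved" here, then every listener is notified.
            for listener in listeners:
                listener.after_save(step)


listener = RecordingListener()
toy_train(max_steps=2500, save_every=1000, listeners=[listener])
print(listener.eval_steps)  # [1000, 2000, 2500]
```

With a real Estimator, the checkpoint interval is configured through tf.estimator.RunConfig (save_checkpoints_steps or save_checkpoints_secs), so an evaluation every 1000 steps, as in the question, would mean checkpointing every 1000 steps.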