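I have been using my own Estimator/Experiment-like code for over a year, but I would like to finally jump on the Dataset+Estimator bandwagon.
I would like to do something like the following: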
for _ in range(N):
  estimator.train(train_input_fn, steps=1000)
  estimator.evaluate(validation_input_fn)
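where train_input_fn creates a tf.data.Dataset that loops over the training set forever, and validation_input_fn creates a tf.data.Dataset that does one pass over the validation set.
Does train() maintain the state of train_input_fn across calls (i.e. only call it once if the reference matches)? Is this how people are doing their training loops with Estimator?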
As I mentioned in my comment above, it looks like it does not save state across calls to estimator.train().
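To illustrate what "not saving state" means for the loop in the question, here is a toy sketch in plain Python. It is not TensorFlow code: ToyEstimator is a hypothetical stand-in, and it assumes only that each train() call invokes input_fn again and builds a fresh iterator rather than resuming the previous one.

```python
import itertools


class ToyEstimator:
    """Hypothetical stand-in for an Estimator; not part of TensorFlow."""

    def __init__(self):
        self.global_step = 0
        self.seen = []

    def train(self, input_fn, steps):
        # Assumption being illustrated: input_fn is called on every train()
        # call, so the input pipeline is rebuilt instead of resumed.
        batches = input_fn()
        for batch in itertools.islice(batches, steps):
            self.seen.append(batch)
            self.global_step += 1


def train_input_fn():
    # Loops forever over a five-element "training set".
    return itertools.cycle(range(5))


estimator = ToyEstimator()
for _ in range(3):
    estimator.train(train_input_fn, steps=2)

print(estimator.seen)         # [0, 1, 0, 1, 0, 1]
print(estimator.global_step)  # 6
```

Under that assumption, the step counter keeps advancing, but every call starts again from the first examples, so elements 2 to 4 are never seen unless the dataset is shuffled or each call runs long enough to cover it.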
A solution that I am going with, and possibly the intended method, is to pass evaluation listeners to estimator.train(). For example,
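import tensorflow as tf

class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
  """Evaluates the estimator every time a checkpoint is saved."""

  def __init__(self, estimator, input_fn):
    self.estimator = estimator
    self.input_fn = input_fn

  def after_save(self, session, global_step):
    # Runs right after a checkpoint is written, so evaluate() picks up
    # the newest weights from the model directory.
    self.estimator.evaluate(self.input_fn)

estimator.train(
  input_fn=lambda: _train_input_fn(...),
  max_steps=N,
  saving_listeners=[
    EvalCheckpointSaverListener(
      estimator,
      lambda: _eval_input_fn(...),
    ),
  ],
)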