我已经使用我自己的 Estimator/Experiment 之类的代码一年多了,但我想最终加入 Dataset+Estimator 的潮流。
我想做如下事情:
for _ in range(N):
estimator.train(train_input_fn, steps=1000)
estimator.evaluate(validation_input_fn)
Run Code Online (Sandbox Code Playgroud)
其中train_input_fn创建一个tf.data.Dataset永远循环遍历训练集,并validation_input_fn创建一个tf.data.Dataset执行一次验证集的传递。
是否train()保持train_input_fn跨调用的状态(即如果引用匹配则只调用一次)?这是人们使用 Estimator 进行训练循环的方式吗?
As I mentioned in my comment above, it looks like it does not save state across calls to estimator.train().
A solution that I am going with, and possibly the intended method, is to pass evaluation listeners to estimator.train(). For example,
class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
def __init__(self, estimator, input_fn):
self.estimator = estimator
self.input_fn = input_fn
def after_save(self, session, global_step):
self.estimator.evaluate(self.input_fn)
estimator.train(
input_fn=lambda:_train_input_fn(...),
max_steps=N,
saving_listeners=[
EvalCheckpointSaverListener(
estimator,
lambda:_eval_input_fn(...),
),
],
)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3055 次 |
| 最近记录: |