How to get the last global_step from tf.estimator.Estimator

mad*_*n25 3 python tensorflow tensorflow-estimator

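How can I get the last global_step from a tf.estimator.Estimator after train(...) has finished? For example, a typical Estimator-based training routine might be set up like this:

import tensorflow as tf

n_epochs = 10
model_dir = '/path/to/model_dir'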

def model_fn(features, labels, mode, params):
    # some code to build the model
    pass

def input_fn():
    ds = tf.data.Dataset()  # obviously with specifying a data source
    # manipulate the dataset
    return ds

run_config = tf.estimator.RunConfig(model_dir=model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

for epoch in range(n_epochs):
    estimator.train(input_fn=input_fn)
    # Now I want to do something which requires to know the last global step, how to get it?
    my_custom_eval_method(global_step)

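Only the evaluate() method returns a dictionary that contains global_step as a field. How can I get global_step if, for some reason, I cannot or do not want to use that method?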

mad*_*n25 5

Just create a hook before the training loop:

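class GlobalStepHook(tf.train.SessionRunHook):
    """Records the most recent global_step value seen during training."""

    def __init__(self):
        self._global_step_tensor = None
        self.value = None  # stays None until at least one training step has run

    def begin(self):
        # Called once per train() call, after the graph has been built.
        self._global_step_tensor = tf.train.get_global_step()

    def after_run(self, run_context, run_values):
        # Read global_step after every training step.
        self.value = run_context.session.run(self._global_step_tensor)

    def __str__(self):
        return str(self.value)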

global_step = GlobalStepHook()
for epoch in range(n_epochs):
    estimator.train(input_fn=input_fn, hooks=[global_step])
    # Now the global_step hook contains the latest value of global_step
    my_custom_eval_method(global_step.value)
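To see why the hook still holds the latest step after each train() call returns, here is a TensorFlow-free sketch of the same pattern. FakeSession, StepRecorder and fake_train are hypothetical stand-ins invented for illustration, not TensorFlow APIs. The sketch assumes, as with an Estimator that resumes from its model_dir, that each call continues counting from where the previous one stopped.

```python
class FakeSession:
    """Hypothetical stand-in for a session; it only tracks a step counter."""

    def __init__(self, start_step):
        self.step = start_step

    def run_train_step(self):
        # A real train op would increment the global_step variable.
        self.step += 1


class StepRecorder:
    """Mirrors GlobalStepHook: remembers the step seen after each run call."""

    def __init__(self):
        self.value = None

    def after_run(self, session):
        self.value = session.step


def fake_train(start_step, steps, hooks):
    """Hypothetical stand-in for estimator.train(): runs `steps` steps."""
    session = FakeSession(start_step)
    for _ in range(steps):
        session.run_train_step()
        for hook in hooks:
            hook.after_run(session)
    # Plays the role of the global_step saved in the last checkpoint.
    return session.step


recorder = StepRecorder()  # created once, outside the loop
resume_from = 0
history = []
for epoch in range(3):
    resume_from = fake_train(resume_from, steps=100, hooks=[recorder])
    history.append(recorder.value)  # value is readable after "train" returns
```

Because the recorder object is created once outside the loop and passed into every call, its value attribute survives after each call ends, which is what lets my_custom_eval_method read it in the answer above.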