使用tf.estimator.Estimator加载检查点和微调

jra*_*ary 6 tensorflow

我们正在尝试将旧的培训代码转换为更符合tf.estimator.Estimator的代码。在初始代码中,我们微调了目标数据集的原始模型。结合使用Variables_to_restoreinit_fnMonitoredTrainingSession进行组合,在进行训练之前仅从检查点加载一些图层。如何使用tf.estimator.Estimator方法实现这种重量负载?

use*_*804 5

你有两个选择,第一个更简单:

1-tf.train.init_from_checkpoint在您的model_fn

2-model_fn返回一个EstimatorSpec. 您可以通过设置脚手架EstimatorSpec