mtn*_*gld 7 tensorflow tensorflow-datasets tensorflow-estimator
我tf.estimator.Estimator用来开发我的模型,
我写了一个model_fn并训练了50,000次迭代,现在我想对我做一个小改动model_fn,例如添加一个新图层.
我不想从头开始训练,我想从50,000检查点恢复所有旧变量,并从这一点继续训练.当我尝试这样做时,我得到了一个NotFoundError
怎么办tf.estimator.Estimator呢?
TL; DR从前一个检查点加载变量的最简单方法是使用该函数tf.train.init_from_checkpoint().只需model_fn在Estimator 内部调用此函数,就会覆盖相应变量的初始值设定项.
更详细地说,假设您已经在MNIST上训练了第一个带有两个隐藏层的模型,命名为model_fn_1.权重保存在目录中mnist_1.
def model_fn_1(features, labels, mode):
images = features['image']
h1 = tf.layers.dense(images, 100, activation=tf.nn.relu, name="h1")
h2 = tf.layers.dense(h1, 100, activation=tf.nn.relu, name="h2")
logits = tf.layers.dense(h2, 10, name="logits")
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
# Estimator 1: two hidden layers
estimator_1 = tf.estimator.Estimator(model_fn_1, model_dir='mnist_1')
estimator_1.train(input_fn=train_input_fn, steps=1000)
Run Code Online (Sandbox Code Playgroud)
现在我们要训练一个model_fn_2有三个隐藏层的新模型.我们要加载的权重前两个隐藏层h1和h2.我们tf.train.init_from_checkpoint()用来做这个:
def model_fn_2(features, labels, mode, params):
images = features['image']
h1 = tf.layers.dense(images, 100, activation=tf.nn.relu, name="h1")
h2 = tf.layers.dense(h1, 100, activation=tf.nn.relu, name="h2")
h3 = tf.layers.dense(h2, 100, activation=tf.nn.relu, name="h3")
assignment_map = {
'h1/': 'h1/',
'h2/': 'h2/'
}
tf.train.init_from_checkpoint('mnist_1', assignment_map)
logits = tf.layers.dense(h3, 10, name="logits")
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
# Estimator 2: three hidden layers
estimator_2 = tf.estimator.Estimator(model_fn_2, model_dir='mnist_2')
estimator_2.train(input_fn=train_input_fn, steps=1000)
Run Code Online (Sandbox Code Playgroud)
在assignment_map将加载从范围的每一个变量h1/在检查站进入新的范围h1/,并用相同的h2/.不要忘记/最后让TensorFlow知道它是一个可变范围.
我找不到使用预先制作的估算器来完成这项工作的方法,因为你无法改变它们model_fn.
| 归档时间: |
|
| 查看次数: |
1933 次 |
| 最近记录: |