I am trying to learn TensorFlow, and I want to take the framework of TensorFlow's cifar10 tutorial and train it on MNIST (combining the two tutorials).
In the train method of cifar10.py, cifar10.train(total_loss, global_step):
    lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                    global_step,
                                    100,
                                    0.1,
                                    staircase=True)
    tf.scalar_summary('learning_rate', lr)
    tf.scalar_summary('global_step', global_step)
global_step is initialized and passed in, and global_step does increase by 1 each step and the learning rate decays correctly. The source code can be found in TensorFlow's cifar10 tutorial.
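For reference, with staircase=True, tf.train.exponential_decay multiplies the initial rate by decay_rate raised to floor(global_step / decay_steps), so the rate only drops if global_step actually advances. Below is a minimal plain-Python sketch of that arithmetic (the helper name and defaults are just for illustration, using the 100/0.1 values from the call above):

    # Sketch of the staircase exponential-decay arithmetic, not TensorFlow code.
    def decayed_lr(initial_rate, step, decay_steps=100, decay_rate=0.1):
        # staircase=True: the exponent only changes every decay_steps steps
        return initial_rate * decay_rate ** (step // decay_steps)

    print(decayed_lr(0.1, 0))    # 0.1
    print(decayed_lr(0.1, 99))   # 0.1
    print(decayed_lr(0.1, 100))  # 0.01 -- the drop requires global_step to move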
However, when I try to do the same thing in the training method of my modified mnist.py, mnist.training(loss, batch_size, global_step):
    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(0.1,
                                    global_step,
                                    100,
                                    0.1,
                                    staircase=True)
    tf.scalar_summary('learning_rate1', lr)
    tf.scalar_summary('global_step1', global_step)
    # Create the gradient descent optimizer with the given learning rate.
    optimizer = tf.train.GradientDescentOptimizer(lr)
    # Create a variable to track the global step.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Use the optimizer to apply the gradients that minimize the loss
    # (and also increment the global step counter) as a single training step.
    train_op = optimizer.minimize(loss, global_step=global_step)
    tf.scalar_summary('global_step2', global_step)
    tf.scalar_summary('learning_rate2', lr)
    return train_op
The global step (in both cifar10 and my mnist file) is initialized like this:
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)
        ...
        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = mnist10.training(loss, batch_size=100,
                                    global_step=global_step)
Here I log scalar summaries of the global step and the learning rate twice each: learning_rate1 and learning_rate2 are identical and stay constant at 0.1 (the initial learning rate). global_step1 also stays at 0 over 2000 steps, while global_step2 increases linearly by 1 per step.
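To narrow down which variable each summary and the optimizer are actually touching, one could print the variable names inside training (a hypothetical diagnostic, not part of the original code, assuming the same pre-1.0 TensorFlow API as above):

    def training(loss, batch_size, global_step):
        print(global_step.name)   # the caller's variable, e.g. "Variable:0"
        # ... exponential_decay, summaries, optimizer as in the code above ...
        global_step = tf.Variable(0, name='global_step', trainable=False)
        print(global_step.name)   # a different variable, e.g. "global_step:0"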
The full code structure can be found at https://bitbucket.org/jackywang529/tesorflow-sandbox/src
This is quite confusing to me. I don't see why this happens (in the case of my global_step: since I thought everything is built symbolically, the global step should be incrementing once the program starts running, regardless of where I write the summary), and I think this is also why my learning rate stays constant. Of course, I may be making some simple mistake, and I would be glad to get some help/explanation.
You are passing an argument called global_step to mnist.training, and you are also creating a variable called global_step inside mnist.training. The variable used by exponential_decay is the one that is passed in, but the variable that actually gets incremented (by being passed to optimizer.minimize) is the newly created one. Simply remove the following statement from mnist.training:
    global_step = tf.Variable(0, name='global_step', trainable=False)
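With that line removed, the training function would read roughly as follows (a sketch rather than the tutorial's exact code, using the same pre-1.0 API as the question). minimize() now increments the same variable that exponential_decay reads, so the decay schedule and the global_step summary both advance as training runs:

    def training(loss, batch_size, global_step):
        # Decay the learning rate based on the caller's global_step variable.
        lr = tf.train.exponential_decay(0.1, global_step, 100, 0.1,
                                        staircase=True)
        tf.scalar_summary('learning_rate', lr)
        tf.scalar_summary('global_step', global_step)
        optimizer = tf.train.GradientDescentOptimizer(lr)
        # minimize() updates the variable that was passed in, so the summaries
        # and the decayed learning rate both reflect real training progress.
        train_op = optimizer.minimize(loss, global_step=global_step)
        return train_op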