小编Ban*_*ana的帖子

Tensorflow:global_step不递增; 因此exponentialDecay不起作用

我正在尝试学习Tensorflow,我想使用Tensorflow的cifar10教程框架并在mnist(结合两个教程)之上进行训练.

在cifar10.py的火车方法中:

cifar10.train(total_loss, global_step):
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,                        
                                  global_step,                                  
                                  100,                                          
                                  0.1,                   
                                  staircase=True)                               
  tf.scalar_summary('learning_rate', lr)                                       
  tf.scalar_summary('global_step', global_step)
Run Code Online (Sandbox Code Playgroud)

global_step被传递初始化并传入,并且global_step确实在一步增加1并且学习速率正确地衰减源代码可以在tensorflow的cifar10教程中找到.

但是,当我尝试为修改后的mnist.py的列车方法代码执行相同操作时:

mnist.training(loss, batch_size, global_step):
  # Decay the learning rate exponentially based on the number of steps.         
  lr = tf.train.exponential_decay(0.1,                                          
                                  global_step,                                  
                                  100,                                             
                                  0.1,                                             
                                  staircase=True)                                  
  tf.scalar_summary('learning_rate1', lr)                                          
  tf.scalar_summary('global_step1', global_step)                                   

  # Create the gradient descent optimizer with the given learning rate.            
  optimizer = tf.train.GradientDescentOptimizer(lr)                                
  # Create a variable to track the global step.                                    
  global_step = tf.Variable(0, name='global_step', trainable=False)                
  # Use the optimizer to apply the gradients that …
Run Code Online (Sandbox Code Playgroud)

python tensorflow

4
推荐指数
1
解决办法
5728
查看次数

标签 统计

python ×1

tensorflow ×1