使用整个 MNIST 数据集(60000 张图像)训练 tensorflow 需要多少次迭代?

Swa*_*nil 5 python mnist tensorflow

MNIST 集由 60,000 张图像组成,用于训练集。在训练我的 Tensorflow 时,我想运行训练步骤以使用整个训练集训练模型。Tensorflow 网站上的深度学习示例使用 20,000 次迭代,批量大小为 50(总计 1,000,000 个批次)。当我尝试超过 30,000 次迭代时,我的数字预测失败(预测所有手写数字为 0)。我的问题是,在批量大小为 50 的情况下,我应该使用多少次迭代来训练带有整个 MNIST 集的 tensorflow 模型?

self.mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
for i in range(FLAGS.training_steps):
    batch = self.mnist.train.next_batch(50)
    self.train_step.run(feed_dict={self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5})
    if (i+1)%1000 == 0:
       saver.save(self.sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step = i)
Run Code Online (Sandbox Code Playgroud)

Yao*_*ang 2

我认为这取决于您的停止标准。当损失没有改善时,您可以停止训练,或者您可以拥有一个验证数据集,并在验证准确性不再改善时停止训练。