TensorFlow:dataset.train.next_batch是如何定义的?

Eda*_*ame 13 neural-network python-3.x autoencoder tensorflow

我正在尝试学习TensorFlow并在以下网址学习示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

然后,我在下面的代码中有一些问题:

for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
    # Display logs per epoch step
    if epoch % display_step == 0:
        print("Epoch:", '%04d' % (epoch+1),
              "cost=", "{:.9f}".format(c))
Run Code Online (Sandbox Code Playgroud)

由于mnist只是一个数据集,究竟是什么mnist.train.next_batch意思呢?怎么dataset.train.next_batch定义?

谢谢!

mrr*_*rry 25

mnist对象从模块中定义的read_data_sets()函数返回tf.contrib.learn.该mnist.train.next_batch(batch_size)方法被实现在这里,它返回两个阵列,其中,所述第一表示一批的元组batch_sizeMNIST图像,并且所述第二表示一批batch-size对应于这些图像的标签.

图像作为2-D NumPy大小的数组返回[batch_size, 784](因为MNIST图像中有784个像素),并且标签作为1-D NumPy大小的数组返回[batch_size](如果read_data_sets()被调用one_hot=False)或2- D NumPy大小数组[batch_size, 10](如果read_data_sets()被调用one_hot=True).

  • 值得一提的是[next_batch](https://github.com/tensorflow/tensorflow/blob/7c36309c37b04843030664cdc64aca2bb7d6ecaa/tensorflow/contrib/learn/python/learn/datasets/mnist.py#L160)经过全部的重新洗牌后的例子.他们每个时代.您可以通过`DataSet._index_in_epoch`跟踪您在这个时代的位置,例如`mnist.train._index_in_epoch` (10认同)