TensorFlow教程batch_xs中的next_batch,batch_ys = mnist.train.next_batch(100)来自哪里?

Dan*_*Dan 14 python numpy tensorflow

我正在尝试TensorFlow教程并且不明白这行中的next_batch来自哪里?

 batch_xs, batch_ys = mnist.train.next_batch(100)
Run Code Online (Sandbox Code Playgroud)

我在看

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Run Code Online (Sandbox Code Playgroud)

并没有看到next_batch那里.

现在,在我自己的代码中尝试next_batch时,我得到了

AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'
Run Code Online (Sandbox Code Playgroud)

所以我想了解next_batch来自哪里?

Nic*_*ker 19

next_batchDataSet类的方法(有关类中的内容的更多信息,请参阅https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py).

加载mnist数据并将其分配给变量mnist时:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Run Code Online (Sandbox Code Playgroud)

看看上课mnist.train.你可以输入以下内容来查看:

print mnist.train.__class__
Run Code Online (Sandbox Code Playgroud)

你会看到以下内容:

<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.Dataset'>
Run Code Online (Sandbox Code Playgroud)

因为mnist.train是类的实例DataSet,所以可以使用类的函数next_batch.有关类的更多信息,请查看文档.


Dar*_*ent 8

在查看tensorflow存储库之后,它似乎源于此处:

https://github.com/tensorflow/tensorflow/blob/9230423668770036179a72414482d45ddde40a3b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py#L905

但是,如果您希望在自己的代码中实现它(对于您自己的数据集),那么在我自己的数据集对象中自己编写它可能要简单得多.据我了解,这是一种对整个数据集进行混洗的方法,并从混洗数据集中返回$ mini_batch_size个样本数.

这是一些伪代码:

shuffle data.x and data.y while retaining relation return [data.x[:mb_n], data.y[:mb_n]]