Dan*_*Dan 14 python numpy tensorflow
我正在尝试TensorFlow教程并且不明白这行中的next_batch来自哪里?
batch_xs, batch_ys = mnist.train.next_batch(100)
Run Code Online (Sandbox Code Playgroud)
我在看
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Run Code Online (Sandbox Code Playgroud)
并没有看到next_batch那里.
现在,在我自己的代码中尝试next_batch时,我得到了
AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'
Run Code Online (Sandbox Code Playgroud)
所以我想了解next_batch来自哪里?
Nic*_*ker 19
next_batch是DataSet类的方法(有关类中的内容的更多信息,请参阅https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py).
加载mnist数据并将其分配给变量mnist时:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Run Code Online (Sandbox Code Playgroud)
看看上课mnist.train.你可以输入以下内容来查看:
print mnist.train.__class__
Run Code Online (Sandbox Code Playgroud)
你会看到以下内容:
<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.Dataset'>
Run Code Online (Sandbox Code Playgroud)
因为mnist.train是类的实例DataSet,所以可以使用类的函数next_batch.有关类的更多信息,请查看文档.
在查看tensorflow存储库之后,它似乎源于此处:
但是,如果您希望在自己的代码中实现它(对于您自己的数据集),那么在我自己的数据集对象中自己编写它可能要简单得多.据我了解,这是一种对整个数据集进行混洗的方法,并从混洗数据集中返回$ mini_batch_size个样本数.
这是一些伪代码:
shuffle data.x and data.y while retaining relation
return [data.x[:mb_n], data.y[:mb_n]]
| 归档时间: |
|
| 查看次数: |
25506 次 |
| 最近记录: |