Mil*_*uss 5 issue-tracking tensorflow tensorflow-datasets
推荐使用tensorflow数据集作为输入管道,可以设置如下:
# Specify dataset
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# Suffle
dataset = dataset.shuffle(buffer_size=1e5)
# Specify batch size
dataset = dataset.batch(128)
# Create an iterator
iterator = dataset.make_one_shot_iterator()
# Get next batch
next_batch = iterator.get_next()
Run Code Online (Sandbox Code Playgroud)
我应该能够获得批量大小(从数据集本身或从它创建的迭代器,即两者iterator和next_batch)。也许有人想知道数据集或其迭代器中有多少批次。或者已经调用了多少批次以及迭代器中剩余多少批次?人们可能还想一次获取特定元素,甚至整个数据集。
我无法在 tensorflow 文档中找到任何内容。这可能吗?如果没有,有谁知道这是否已被请求作为 tensorflow GitHub 上的问题?
尝试这个
import tensorflow as tf
import numpy as np
features=np.array([[3.0, 0.0], [1.0, 2.0], [0.0, 0.0]], dtype="float32")
labels=np.array([[0], [0], [1]], dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
batch_size = 2
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer)
print(np.shape(sess.run(batch_data)[0])[0])Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3504 次 |
| 最近记录: |