如何从张量流数据集中取回批量大小?

Mil*_*uss 5 issue-tracking tensorflow tensorflow-datasets

推荐使用tensorflow数据集作为输入管道,可以设置如下:

# Specify dataset
dataset  = tf.data.Dataset.from_tensor_slices((features, labels))
# Suffle
dataset  = dataset.shuffle(buffer_size=1e5)
# Specify batch size
dataset  = dataset.batch(128)
# Create an iterator
iterator = dataset.make_one_shot_iterator()
# Get next batch
next_batch = iterator.get_next()
Run Code Online (Sandbox Code Playgroud)

我应该能够获得批量大小(从数据集本身或从它创建的迭代器,即两者iteratornext_batch)。也许有人想知道数据集或其迭代器中有多少批次。或者已经调用了多少批次以及迭代器中剩余多少批次?人们可能还想一次获取特定元素,甚至整个数据集。

我无法在 tensorflow 文档中找到任何内容。这可能吗?如果没有,有谁知道这是否已被请求作为 tensorflow GitHub 上的问题?

guo*_*rui 1

尝试这个

import tensorflow as tf
import numpy as np

features=np.array([[3.0, 0.0], [1.0, 2.0], [0.0, 0.0]], dtype="float32")
labels=np.array([[0], [0], [1]], dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

batch_size = 2
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(np.shape(sess.run(batch_data)[0])[0])
Run Code Online (Sandbox Code Playgroud)
你会看到 在此输入图像描述