oms*_*gar 8 python tensorflow tensorflow-datasets
我理解Dataset API是一种迭代器,它不会将整个数据集加载到内存中,因此无法找到数据集的大小.我正在谈论存储在文本文件或tfRecord文件中的大型数据语料库.通常使用tf.data.TextLineDataset或类似的东西来读取这些文件.找到使用的数据集加载大小是微不足道的tf.data.Dataset.from_tensor_slices.
我问数据集大小的原因如下:假设我的数据集大小为1000个元素.批量大小= 50个元素.然后训练步骤/批次(假设1个纪元)= 20.在这20个步骤中,我想将我的学习率从0.1到0.01指数衰减为
tf.train.exponential_decay(
learning_rate = 0.1,
global_step = global_step,
decay_steps = 20,
decay_rate = 0.1,
staircase=False,
name=None
)
Run Code Online (Sandbox Code Playgroud)
在上面的代码中,我有"和"想要设置decay_steps = number of steps/batches per epoch = num_elements/batch_size.仅当预先知道数据集中的元素数量时,才能计算此值.
另一个原因预先知道尺寸是将数据拆分为使用训练集和测试集tf.data.Dataset.take(),tf.data.Dataset.skip()方法.
PS:我不是在寻找蛮力方法,例如迭代整个数据集并更新计数器来计算元素数量或放置非常大的批量大小,然后查找结果数据集的大小等.
您可以选择手动指定数据集的大小吗?
我如何加载数据:
sample_id_hldr = tf.placeholder(dtype=tf.int64, shape=(None,), name="samples")
sample_ids = tf.Variable(sample_id_hldr, validate_shape=False, name="samples_cache")
num_samples = tf.size(sample_ids)
data = tf.data.Dataset.from_tensor_slices(sample_ids)
# "load" data by id:
# return (id, data) for each id
data = data.map(
lambda id: (id, some_load_op(id))
)
Run Code Online (Sandbox Code Playgroud)
sample_ids在这里,您可以通过使用占位符初始化一次来指定所有示例 ID 。
您的示例 ID 可以是文件路径或简单数字 ( np.arange(num_elems))
然后可以在 中获得元素的数量num_samples。
| 归档时间: |
|
| 查看次数: |
2094 次 |
| 最近记录: |