cbo*_*nho 5 python tensorflow tensorflow-datasets
想象一下我有:
我想从两个数据集中提取批次并将它们连接起来,以便获得大小为 3 的批次,其中:
如果某些数据集首先被清空,我还想读取最后一批。在这种情况下,我会得到 [5, 5, 4], [5, 5, 4], [5] 作为我的最终结果。
我怎样才能做到这一点?我在这里看到了答案:Tensorflow 如何生成不平衡的组合数据集
这是一个很好的尝试,但如果其中一个数据集在其他数据集之前被清空,则它不起作用(因为tf.errors.OutOfRangeError
当您尝试从首先被清空的数据集中获取元素时,然后被抢先输出,而我没有得到最后一批)。因此我只得到 [5, 5, 4], [5, 5, 4]
我想过使用tf.contrib.data.choose_from_datasets
:
ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)
choice_dataset = [1, 2, 1, 2, 1]
ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)
Run Code Online (Sandbox Code Playgroud)
这种工作,但相当不雅(有unbatch和batch);此外,我对批次中的确切内容并没有很好的控制。(例如,如果 ds1 是 [7] * 7 批量大小为 2,而 ds2 是 [2, 2] 批量大小为 1,我会得到 [7, 7, 1], [7, 7, 1], [7 , 7, 7]. 但是如果我真的想要 [7, 7, 1], [7, 7, 1], [7, 7], [7] 呢?即保持每个数据集中的元素数量固定.
还有其他更好的解决方案吗?
我的另一个想法是尝试使用tf.data.Dataset.flat_map
:
ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])
batch_sizes = [2, 1]
def concat(*inputs):
concat = partial(functools.reduce, lambda x, y: x.concatenate(y))
datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
return concat(datasets)
dataset = (tf.data.Dataset
.zip((ds1, ds2))
.flat_map(_concat_and_batch)
.batch(sum(batch_sizes)))
Run Code Online (Sandbox Code Playgroud)
但它似乎不起作用..
如果您不介意在构建新数据集期间运行会话,则可以执行以下操作:
import tensorflow as tf
import numpy as np
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
batch1 = iter1.get_next()
batch2 = iter2.get_next()
sess = tf.Session()
# define a generator that will sess.run both datasets, and will return the concatenation of both
def GetBatch():
while True:
try:
b1 = sess.run(batch1)
except tf.errors.OutOfRangeError:
b1 = None
try:
b2 = sess.run(batch2)
except tf.errors.OutOfRangeError:
b2 = None
if (b1 is None) and (b2 is None):
break
elif b1 is None:
yield b2
elif b2 is None:
yield b1
else:
yield np.concatenate((b1,b2))
# create a dataset from the above generator
ds = tf.data.Dataset.from_generator(GetBatch,tf.int32)
Run Code Online (Sandbox Code Playgroud)
请注意,如果您愿意,可以隐藏\封装上述会话(例如,在函数内部),例如:
iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess2 = tf.Session()
while True:
print(sess2.run(batch))
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
4606 次 |
最近记录: |