spl*_*eez 6 python tensorflow tensorflow-datasets
我有一个这样的数据集:
a = tf.data.Dataset.range(1, 16)
b = tf.data.Dataset.range(16, 32)
zipped = tf.data.Dataset.zip((a, b))
list(zipped.as_numpy_iterator())
# output:
[(0, 16),
(1, 17),
(2, 18),
(3, 19),
(4, 20),
(5, 21),
(6, 22),
(7, 23),
(8, 24),
(9, 25),
(10, 26),
(11, 27),
(12, 28),
(13, 29),
(14, 30),
(15, 31)]
Run Code Online (Sandbox Code Playgroud)
当我应用batch(4)它时,预期结果是一个批次数组,其中每个批次包含四个元组:
[[(0, 16), (1, 17), (2, 18), (3, 19)],
[(4, 20), (5, 21), (6, 22), (7, 23)],
[(9, 24), (10, 25), (10, 26), (11, 27)],
[(12, 28), (13, 29), (14, 30), (15, 31)]]
Run Code Online (Sandbox Code Playgroud)
但这是我收到的:
batched = zipped.batch(4)
list(batched.as_numpy_iterator())
# Output:
[(array([0, 1, 2, 3]), array([16, 17, 18, 19])),
(array([4, 5, 6, 7]), array([20, 21, 22, 23])),
(array([ 8, 9, 10, 11]), array([24, 25, 26, 27])),
(array([12, 13, 14, 15]), array([28, 29, 30, 31]))]
Run Code Online (Sandbox Code Playgroud)
我正在遵循本教程,他执行相同的步骤,但以某种方式获得正确的输出。
更新:根据文档,这是预期的行为:
结果元素的组件将有一个额外的外部维度,即batch_size
但这没有任何意义。据我理解,数据集是数据的列表。这些数据的形状并不重要,当我们对其进行批处理时,我们会将元素[无论其形状是什么]组合成批次,因此它应该始终将新维度插入到第二个位置((length, a, b, c)-> (length', batch_size, a, b, c))。
batch()所以我的问题是:我想知道这样实施的目的是什么?还有什么替代方案可以实现我所描述的功能呢?
您可以尝试做的一件事是这样的:
import tensorflow as tf
a = tf.data.Dataset.range(16)
b = tf.data.Dataset.range(16, 32)
zipped = tf.data.Dataset.zip((a, b)).batch(4).map(lambda x, y: tf.transpose([x, y]))
list(zipped.as_numpy_iterator())
Run Code Online (Sandbox Code Playgroud)
[array([[ 0, 16],
[ 1, 17],
[ 2, 18],
[ 3, 19]]),
array([[ 4, 20],
[ 5, 21],
[ 6, 22],
[ 7, 23]]),
array([[ 8, 24],
[ 9, 25],
[10, 26],
[11, 27]]),
array([[12, 28],
[13, 29],
[14, 30],
[15, 31]])]
Run Code Online (Sandbox Code Playgroud)
但它们仍然不是元组。或者:
zipped = tf.data.Dataset.zip((a, b)).batch(4).map(lambda x, y: tf.unstack(tf.transpose([x, y]), num = 4))
Run Code Online (Sandbox Code Playgroud)
[(array([ 0, 16]), array([ 1, 17]), array([ 2, 18]), array([ 3, 19])), (array([ 4, 20]), array([ 5, 21]), array([ 6, 22]), array([ 7, 23])), (array([ 8, 24]), array([ 9, 25]), array([10, 26]), array([11, 27])), (array([12, 28]), array([13, 29]), array([14, 30]), array([15, 31]))]
Run Code Online (Sandbox Code Playgroud)