如何取消批处理 Tensorflow 2.0 数据集

Tom*_*t45 3 python machine-learning keras tensorflow tensorflow-datasets

我有一个使用以下代码创建的数据集tf.data.Dataset

dataset = Dataset.from_tensor_slices(corona_new)
dataset = dataset.window(WINDOW_SIZE, 1, drop_remainder=True)
dataset = dataset.flat_map(lambda x: x.batch(WINDOW_SIZE))
dataset = dataset.map(lambda x: tf.transpose(x))

for i in dataset:
    print(i.numpy())
    break
Run Code Online (Sandbox Code Playgroud)

当我运行它时,我得到以下输出(这是一批的示例):

[[  0. 125. 111. 232. 164. 134. 235. 190.] 
 [  0.  14.  16.   7.   9.   7.   6.   8.]
 [  0. 132. 199. 158. 148. 141. 179. 174.]
 [  0.   0.   0.   2.   0.   2.   1.   2.]
 [  0.   0.   0.   0.   3.   5.   0.   0.]]
Run Code Online (Sandbox Code Playgroud)

我怎样才能取消它们?

Tom*_*t45 6

找到了我的解决方案。

在 TensorFlow 2.0 中,您可以tf.data.Dataset通过调用该.unbatch()函数来取消批处理。

例子:dataset.unbatch()