如何使用带有嵌套形状的 tf.data.Dataset.padded_batch?

fez*_*zik 5 python tensorflow tensorflow-datasets

我正在为每个元素构建一个具有 [batch,width,heigh,3] 和 [batch,class] 形状张量的数据集。为简单起见,假设 class = 5。

你给什么形状的dataset.padded_batch(1000,shape)图像沿着宽度/高度/3轴填充?

我尝试了以下方法:

tf.TensorShape([[None,None,None,3],[None,5]])
[tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])]
[[None,None,None,3],[None,5]]
([None,None,None,3],[None,5])
(tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])??)
Run Code Online (Sandbox Code Playgroud)

每个引发 TypeError

文档状态:

padded_shapes:tf.TensorShape 或 tf.int64 向量张量类对象的嵌套结构,表示在批处理之前每个输入元素的相应组件应填充到的形状。任何未知的维度(例如 tf.TensorShape 中的 tf.Dimension(None) 或类似张量的对象中的 -1)将被填充到每个批次中该维度的最大大小。

相关代码:

dataset = tf.data.Dataset.from_generator(generator,tf.float32)
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)
Run Code Online (Sandbox Code Playgroud)

fez*_*zik 8

感谢 mrry 找到解决方案。原来 from_generator 中的类型必须与条目中的张量数相匹配。

新代码:

dataset = tf.data.Dataset.from_generator(generator,(tf.float32,tf.float32))
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)
Run Code Online (Sandbox Code Playgroud)