如何获取 tf.data.dataset 的形状?

cao*_*gyu 13 machine-learning deep-learning tensorflow tensorflow-datasets

我知道数据集有 output_shapes,但它显示如下:

data_set: DatasetV1Adapter 形状: {item_id_hist: (?, ?), tags: (?, ?), client_platform: (?,), entry: (?,), item_id: (?,), label: (?,),模式:(?,),时间:(?,),user_id:(?,)},类型:{item_id_hist:tf.int64,标签:tf.int64,client_platform:tf.string,入口:tf.string,item_id :tf.int64,标签:tf.int64,模式:tf.int64,时间:tf.int64,user_id:tf.int64}

我怎样才能得到我的数据总数?

Ste*_*t_R 14

如果长度已知,您可以调用:

tf.data.experimental.cardinality(dataset)
Run Code Online (Sandbox Code Playgroud)

但是如果这失败了,重要的是要知道 TensorFlow Dataset(通常)是惰性求值的,所以这意味着在一般情况下,我们可能需要遍历每条记录,然后才能找到数据集的长度。

例如,假设您启用了 Eager Execution 并且它是一个适合内存的小型“玩具”数据集,您可以将enumerate其放入一个新列表并获取最后一个索引(然后加 1,因为列表是零索引的):

dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1
Run Code Online (Sandbox Code Playgroud)

当然,这充其量是低效的,对于大型数据集,将完全失败,因为一切都需要适合列表的内存。在这种情况下,除了遍历保持手动计数的记录之外,我看不到任何其他选择。