如何合并两个(或更多)TensorFlow 数据集?

Vah*_*ili 5 python tensorflow tensorflow-datasets tensorflow2.0

我已获取具有 3 个分区的 CelebA 数据集,如下所示

>>> celeba_bldr = tfds.builder('celeb_a')
>>> datasets = celeba_bldr.as_dataset()
>>> datasets.keys()
dict_keys(['test', 'train', 'validation'])

ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']
Run Code Online (Sandbox Code Playgroud)

现在,我想将它们全部合并到一个数据集中。例如,我需要将训练和验证结合在一起,或者可能将它们全部合并在一起,然后根据我自己的不同主题不相交标准将它们分开。有办法做到这一点吗?

我在文档中找不到任何选项来执行此操作https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset

Add*_*ddy 10

查看您链接的文档,数据集似乎有concatenate方法,所以我假设您可以获得一个联合数据集:

ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']

ds = ds_train.concatenate(ds_test).concatenate(ds_valid)
Run Code Online (Sandbox Code Playgroud)

请参阅: https: //www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#concatenate


Ima*_*deh 8

我还想提一下,如果您需要连接多个数据集(例如,数据集列表),您可以以更有效的方式进行:

ds_l = [ds_1, ds_2, ds_3] # list of `Dataset` objects

# 1. create dataset where each element is a `tf.data.Dataset` object
ds = tf.data.Dataset.from_tensor_slices(ds_l)

# 2. extract all elements from datasets and concat them into one dataset
concat_ds = ds.interleave(
    lambda x: x,
    cycle_length=1,
    num_parallel_calls=tf.data.AUTOTUNE,
)

Run Code Online (Sandbox Code Playgroud)

您也可以使用flat_map(),但我认为使用interleave()并行调用会更快。总的来说interleave是一个概括flat_map