将Tensorflow数据集API创建的数据集拆分为Train and Test?

Dan*_*ani 21 dataset tensorflow

有谁知道如何将Tensorflow中数据集API(tf.data.Dataset)创建的数据集拆分为Test and Train?

apa*_*kin 28

假设你有类型的all_dataset变量tf.data.Dataset:

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)
Run Code Online (Sandbox Code Playgroud)

测试数据集现在有前1000个元素,其余的用于训练.

  • 正如[ted的回答](/sf/answers/3588108681/)中也提到的,添加 `all_dataset.shuffle()` 允许进行随机分割。可能会像这样在答案中添加代码注释吗?`# all_dataset = all_dataset.shuffle() # 如果你想要随机分割` (4认同)

Pat*_*ick 16

这里的大多数答案都使用take()and skip(),这需要事先知道数据集的大小。这并不总是可能的,或者很难/密集地确定。

相反,您可以做的是从本质上对数据集进行切片,以便每 N 条记录 1 个成为验证记录。

为此,让我们从一个简单的 0-9 数据集开始:

dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Run Code Online (Sandbox Code Playgroud)

现在对于我们的示例,我们将对其进行切片,以便我们有 3/1 的训练/验证拆分。这意味着 3 条记录将进行训练,然后 1 条记录进行验证,然后重复。

split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
Run Code Online (Sandbox Code Playgroud)

所以第一个dataset.window(split, split + 1)说要抓取split(3)个元素,然后推进split + 1元素,然后重复。这+ 1有效地跳过了我们将在验证数据集中使用的 1 个元素。
flat_map(lambda ds: ds)是因为window()批量返回结果,这是我们不想要的。所以我们把它压平。

然后对于验证数据,我们首先skip(split)跳过在第一个训练窗口中抓取的第一个元素split(3),因此我们在第 4 个元素上开始迭代。所述window(1, split + 1)然后抓起1个元件,预付款split + 1 (4) ,并重复。

 

关于嵌套数据集的注意事项:
上面的示例适用于简单的数据集,但flat_map()如果数据集是嵌套的,则会产生错误。为了解决这个问题,你可以flat_map()用一个可以处理简单和嵌套数据集的更复杂的版本来替换:

.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
Run Code Online (Sandbox Code Playgroud)

  • 如果您有一个包含 1000 条记录的数据集,并且需要 10% 的记录进行验证,则必须在发出单个验证记录之前跳过前 900 条记录。使用此解决方案,只需跳过 9 条记录。它最终确实会跳过相同的数量,但如果您使用“dataset.prefetch()”,它可以在执行其他操作时在后台读取。区别只是节省了初始假脱机时间。 (2认同)
  • 您可能应该将*在事先不知道数据集大小的情况下*设置为粗体,或者像标题或其他东西一样,这非常重要。这确实应该是公认的答案,因为它符合“tf.data.Dataset”将数据视为无限流的前提。 (2认同)

ted*_*ted 13

您可以使用Dataset.take()Dataset.skip()

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
Run Code Online (Sandbox Code Playgroud)

为了更笼统,我举了一个使用70/15/15 train / val / test split的示例,但是如果您不需要测试或val集,则只需忽略最后两行。

采取

从此数据集中创建一个最多包含count个元素的数据集。

跳过

创建一个数据集,该数据集从该数据集中跳过计数元素。

您可能还需要调查Dataset.shard()

创建一个仅包含此数据集的1 / num_shards的数据集。


免责声明我就这个问题回答绊倒后,这一个,所以我想我会传播爱

  • 非常感谢你@ted!有没有办法对数据集进行分层划分?或者,我们如何了解训练/验证/测试分割后的类比例(假设是二元问题)?预先非常感谢! (3认同)
  • @c_student我遇到了同样的问题,我发现我错过了什么:当你洗牌时使用选项`reshuffle_each_iteration=False`,否则元素可以在train、test和val中重复 (3认同)
  • 这导致我的训练、验证和测试数据集之间存在重叠。这应该发生并且没什么大不了的吗?我认为在验证和测试数据上训练模型并不是一个好主意。 (2认同)
  • 这是非常真实的@xdola,特别是在使用 `list_files` 时,您应该使用 `shuffle=False`,然后使用 `.shuffle` 和 `reshuffle_each_iteration=False` 进行随机播放。 (2认同)

Han*_*ank 6

@ted 的回答会造成一些重叠。尝试这个。

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
Run Code Online (Sandbox Code Playgroud)

使用下面的代码进行测试。

tf.enable_eager_execution()

dataset = tf.data.Dataset.range(100)

train_size = 20
valid_size = 30
test_size = 50

train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)

for i in train:
    print(i)

for i in valid:
    print(i)

for i in test:
    print(i)
Run Code Online (Sandbox Code Playgroud)

  • 我喜欢每个人都假设你知道“full_ds_size”,但没有人解释如何找到它 (3认同)

Ben*_*Uri 5

您可以使用shard

dataset = dataset.shuffle()  # optional
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)
Run Code Online (Sandbox Code Playgroud)

请参阅: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard

  • 分片已被贬值 (6认同)
  • @vgoklani 你确定吗?我没有看到任何说它已被弃用的内容。 (5认同)

Rob*_*lak 5

即将推出的 TensorFlow 2.10.0 将有一个tf.keras.utils.split_dataset function,请参阅rc3 发行说明

添加了将对象或数组列表/元组tf.keras.utils.split_dataset拆分为两个对象的实用程序(例如训练/测试)。DatasetDataset


小智 4

现在 Tensorflow 不包含任何用于此目的的工具。
您可以使用sklearn.model_selection.train_test_split生成训练/评估/测试数据集,然后tf.data.Dataset分别创建。

  • sklearn 要求内容适合内存,而 TF Data 则不需要。 (6认同)