Dan*_*ani 21 dataset tensorflow
有谁知道如何将Tensorflow中数据集API(tf.data.Dataset)创建的数据集拆分为Test and Train?
apa*_*kin 28
假设你有类型的all_dataset变量tf.data.Dataset:
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)
Run Code Online (Sandbox Code Playgroud)
测试数据集现在有前1000个元素,其余的用于训练.
Pat*_*ick 16
这里的大多数答案都使用take()and skip(),这需要事先知道数据集的大小。这并不总是可能的,或者很难/密集地确定。
相反,您可以做的是从本质上对数据集进行切片,以便每 N 条记录 1 个成为验证记录。
为此,让我们从一个简单的 0-9 数据集开始:
dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Run Code Online (Sandbox Code Playgroud)
现在对于我们的示例,我们将对其进行切片,以便我们有 3/1 的训练/验证拆分。这意味着 3 条记录将进行训练,然后 1 条记录进行验证,然后重复。
split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
Run Code Online (Sandbox Code Playgroud)
所以第一个dataset.window(split, split + 1)说要抓取split第(3)个元素,然后推进split + 1元素,然后重复。这+ 1有效地跳过了我们将在验证数据集中使用的 1 个元素。
这flat_map(lambda ds: ds)是因为window()批量返回结果,这是我们不想要的。所以我们把它压平。
然后对于验证数据,我们首先skip(split)跳过在第一个训练窗口中抓取的第一个元素split数(3),因此我们在第 4 个元素上开始迭代。所述window(1, split + 1)然后抓起1个元件,预付款split + 1 (4) ,并重复。
关于嵌套数据集的注意事项:
上面的示例适用于简单的数据集,但flat_map()如果数据集是嵌套的,则会产生错误。为了解决这个问题,你可以flat_map()用一个可以处理简单和嵌套数据集的更复杂的版本来替换:
.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
Run Code Online (Sandbox Code Playgroud)
ted*_*ted 13
您可以使用Dataset.take()和Dataset.skip():
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
Run Code Online (Sandbox Code Playgroud)
为了更笼统,我举了一个使用70/15/15 train / val / test split的示例,但是如果您不需要测试或val集,则只需忽略最后两行。
采取:
从此数据集中创建一个最多包含count个元素的数据集。
跳过:
创建一个数据集,该数据集从该数据集中跳过计数元素。
您可能还需要调查Dataset.shard():
创建一个仅包含此数据集的1 / num_shards的数据集。
免责声明我就这个问题回答绊倒后,这一个,所以我想我会传播爱
@ted 的回答会造成一些重叠。尝试这个。
train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)
train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
Run Code Online (Sandbox Code Playgroud)
使用下面的代码进行测试。
tf.enable_eager_execution()
dataset = tf.data.Dataset.range(100)
train_size = 20
valid_size = 30
test_size = 50
train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)
for i in train:
print(i)
for i in valid:
print(i)
for i in test:
print(i)
Run Code Online (Sandbox Code Playgroud)
您可以使用shard:
dataset = dataset.shuffle() # optional
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)
Run Code Online (Sandbox Code Playgroud)
请参阅: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
即将推出的 TensorFlow 2.10.0 将有一个tf.keras.utils.split_dataset function,请参阅rc3 发行说明:
添加了将对象或数组列表/元组
tf.keras.utils.split_dataset拆分为两个对象的实用程序(例如训练/测试)。DatasetDataset
小智 4
现在 Tensorflow 不包含任何用于此目的的工具。
您可以使用sklearn.model_selection.train_test_split生成训练/评估/测试数据集,然后tf.data.Dataset分别创建。
| 归档时间: |
|
| 查看次数: |
9319 次 |
| 最近记录: |