使用 tf.data.Dataset.map() 混合增强样本和原始样本

Adi*_*tya 0 python tensorflow

据我从张量流文档中了解到,地图用于基于函数 parse_function_wrapper 修改图像。

dataset = dataset.map(parse_function_wrapper,
                  num_parallel_calls=4)
dataset = dataset.batch(32)
Run Code Online (Sandbox Code Playgroud)

现在数据集将只有增强图像而没有原始图像。所以我的疑问是我们需要使用原始数据和增强数据来训练我们的模型。谁能告诉我如何用原始数据进行训练?

GPh*_*ilo 5

我看到两个简单的解决方案:

1)保留原始数据集和增强数据集,然后对其进行压缩、flat_map 和 shuffle:

augmented = dataset.map(parse_function_wrapper,
                  num_parallel_calls=4)
mixed_dataset = (tf.data.Dataset.zip([dataset, augmented])
                 .flat_map(lambda x: x)
                 .shuffle(BUFFER_SIZE)) # use an appropriate buffer size
Run Code Online (Sandbox Code Playgroud)

2)parse_function_wrapper通过应用概率增强p < 1并以概率返回未修改的输入来实现随机1-p。再加上重复数据集,可以获得与之前的解决方案类似的效果,但逻辑上更容易理解。此外,通过这种方式,您可以更好地控制训练数据集中增强样本与原始样本的比率,因为您可以明确设置“混合”数据集应该是增强数据的百分比。