如何在 tensorflow tfrecords 中增加数据?

dee*_*ndu 6 deep-learning tensorflow tensorflow-slim tensorflow-datasets tensorflow-estimator

我使用 tfrecords 存储我的数据,我使用DatasetAPI作为张量读取它们,然后我使用EstimatorAPI 进行训练。现在,我想对数据集中的每个项目进行在线数据增强,但尝试了一段时间后,我找不到办法做到这一点。我想要随机翻转,随机旋转和其他操纵器。

我正在按照教程中给出的说明使用自定义估计器,这是我的 CNN,但我不确定数据增强步骤发生在哪里。

Oli*_*rot 5

使用 TFRecords 不会阻止您进行数据扩充。

按照您在评论中链接的教程,大致会发生以下情况:

  • 您从 TFRecords 文件创建数据集,并解析该文件以获取一个image和一个label
dataset = tf.data.TFRecordDataset(filenames=filenames)
dataset = dataset.map(parse)
Run Code Online (Sandbox Code Playgroud)
  • 您现在可以应用新的预处理功能在训练期间进行一些数据扩充
# Only do it when we are training
if train:
    dataset = dataset.map(train_preprocess)
Run Code Online (Sandbox Code Playgroud)
  • train_preprocess函数可以是这样的:
def train_preprocess(image, label):
    flip_image = tf.image.random_flip_left_right(image)
    # Other transformations...
    return flip_image, label
Run Code Online (Sandbox Code Playgroud)