tf.dataset.Dataset 上的数据增强

Luk*_*sen 6 python keras tensorflow

为了使用 Google Colabs TPU,我需要一个tf.dataset.Dataset. 那么如何在这样的数据集上使用数据增强呢?

更具体地说,到目前为止我的代码是:

def get_dataset(batch_size=200):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
                             try_gcs=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0

    label = tf.one_hot(label,10)

    return image, label

  train_dataset = mnist_train.map(scale).shuffle(10000).batch(batch_size)
  test_dataset = mnist_test.map(scale).batch(batch_size)

  return train_dataset, test_dataset
Run Code Online (Sandbox Code Playgroud)

这是输入到这个中的:

# TPU Strategy ...
with strategy.scope():
  model = create_model()
  model.compile(loss="categorical_crossentropy",
                optimizer="adam",
                metrics=["acc"])

train_dataset, test_dataset = get_dataset()

model.fit(train_dataset,
          epochs=20,
          verbose=1,
          validation_data=test_dataset)
Run Code Online (Sandbox Code Playgroud)

那么,我如何在这里使用数据增强呢?据我所知,我不能使用 tf.keras ImageDataGenerator,对吧?

我尝试了以下方法,但没有成功。

data_generator = ...

model.fit_generator(data_generator.flow(train_dataset, batch_size=32),
                    steps_per_epoch=len(train_dataset) / 32, epochs=20)
Run Code Online (Sandbox Code Playgroud)

这并不奇怪,因为通常 train_x 和 train_y 作为两个参数提供给流函数,而不是“打包”到一个中tf.dataset.Dataset

小智 8

您可以使用tf.image函数。该tf.image模块包含用于图像处理的各种功能。

例如:

您可以在您的 function 中添加以下功能def get_dataset

  • 将每个图像转换为该范围tf.float64内的图像0-1
  • cache()结果,因为这些结果可以在每次之后重复使用repeat
  • 使用 随机翻转每个图像random_flip_left_right
  • 使用 随机改变图像的对比度random_contrast
  • 图像数量增加两倍,repeat重复所有步骤。

代码 -

mnist_train = mnist_train.map(
    lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache(
).map(
    lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
    1000
).
batch(
    batch_size
).repeat(2)
Run Code Online (Sandbox Code Playgroud)

同样,您可以使用其他功能,例如random_flip_up_downrandom_crop函数来随机垂直翻转图像(上下颠倒)和随机将张量裁剪为给定大小。


你的get_dataset函数将如下所示 -

def get_dataset(batch_size=200):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
                             try_gcs=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  train_dataset = mnist_train.map(
               lambda image, label: (tf.image.convert_image_dtype(image, tf.float32),label)
              ).cache(
              ).map(
                    lambda image, label: (tf.image.random_flip_left_right(image), label)
              ).map(
                    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
              ).shuffle(
                    1000
              ).batch(
                    batch_size
              ).repeat(2)

  test_dataset = mnist_test.map(scale).batch(batch_size)

  return train_dataset, test_dataset
Run Code Online (Sandbox Code Playgroud)

添加 @Andrew H 建议的链接,该链接提供了也使用数据集的数据增强的端到端示例mnist

希望这能回答您的问题。快乐学习。