Luk*_*sen 6 python keras tensorflow
为了使用 Google Colabs TPU,我需要一个tf.dataset.Dataset. 那么如何在这样的数据集上使用数据增强呢?
更具体地说,到目前为止我的代码是:
def get_dataset(batch_size=200):
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
try_gcs=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255.0
label = tf.one_hot(label,10)
return image, label
train_dataset = mnist_train.map(scale).shuffle(10000).batch(batch_size)
test_dataset = mnist_test.map(scale).batch(batch_size)
return train_dataset, test_dataset
Run Code Online (Sandbox Code Playgroud)
这是输入到这个中的:
# TPU Strategy ...
with strategy.scope():
model = create_model()
model.compile(loss="categorical_crossentropy",
optimizer="adam",
metrics=["acc"])
train_dataset, test_dataset = get_dataset()
model.fit(train_dataset,
epochs=20,
verbose=1,
validation_data=test_dataset)
Run Code Online (Sandbox Code Playgroud)
那么,我如何在这里使用数据增强呢?据我所知,我不能使用 tf.keras ImageDataGenerator,对吧?
我尝试了以下方法,但没有成功。
data_generator = ...
model.fit_generator(data_generator.flow(train_dataset, batch_size=32),
steps_per_epoch=len(train_dataset) / 32, epochs=20)
Run Code Online (Sandbox Code Playgroud)
这并不奇怪,因为通常 train_x 和 train_y 作为两个参数提供给流函数,而不是“打包”到一个中tf.dataset.Dataset。
小智 8
您可以使用tf.image函数。该tf.image模块包含用于图像处理的各种功能。
例如:
您可以在您的 function 中添加以下功能def get_dataset。
tf.float64内的图像0-1。cache()结果,因为这些结果可以在每次之后重复使用repeatrandom_flip_left_right。random_contrast。repeat重复所有步骤。代码 -
mnist_train = mnist_train.map(
lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache(
).map(
lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
1000
).
batch(
batch_size
).repeat(2)
Run Code Online (Sandbox Code Playgroud)
同样,您可以使用其他功能,例如random_flip_up_down、random_crop函数来随机垂直翻转图像(上下颠倒)和随机将张量裁剪为给定大小。
你的get_dataset函数将如下所示 -
def get_dataset(batch_size=200):
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
try_gcs=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
train_dataset = mnist_train.map(
lambda image, label: (tf.image.convert_image_dtype(image, tf.float32),label)
).cache(
).map(
lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
1000
).batch(
batch_size
).repeat(2)
test_dataset = mnist_test.map(scale).batch(batch_size)
return train_dataset, test_dataset
Run Code Online (Sandbox Code Playgroud)
添加 @Andrew H 建议的链接,该链接提供了也使用数据集的数据增强的端到端示例mnist。
希望这能回答您的问题。快乐学习。
| 归档时间: |
|
| 查看次数: |
6763 次 |
| 最近记录: |