在 Google Cloud Platform 中为 Keras ML 训练读取存储在桶中的数据的理想方法?

Moo*_*dra 4 python google-cloud-platform keras tensorflow

这是我第一次尝试在云中训练模型,我正在努力解决所有的小问题。我将训练数据存储在谷歌云平台内的存储桶中 gs://test/train ,数据集大约为 100k。目前,数据根据其标签分布在不同的文件夹中。

我不知道访问数据的理想方式。通常在Keras我使用,ImageDataGeneratorflow_from_directory它自动创建一个发电机,我可以喂到我的模型。

谷歌云平台是否有诸如 Python 之类的函数?

如果不是,通过生成器访问数据的理想方式是什么,以便我可以将其提供给 Keras model.fit_generator

谢谢你。

sdc*_*cbr 5

ImageDataGenerator.flow_from_directory()当前不允许您直接从 GCS 存储桶流式传输数据。我认为你有几个选择:

1/ 将数据从 GCS 复制到用于运行脚本的 VM 本地磁盘。我想您是通过 ML Engine 或在 Compute Engine 实例上执行此操作。无论哪种方式,您都可以使用gsutilpython 云存储 API在训练脚本的开头复制数据。这里有一个缺点:这会在脚本开始时花费您一些时间,尤其是当数据集很大时。

2/ 使用时tf.keras,您可以在tf.data数据集上训练您的模型。这里的好处是 TensorFlow 的 io 实用程序允许您直接从 GCS 存储桶中读取。如果要将数据转换为 TFRecords,则可以实例化 Dataset 对象,而无需先将数据下载到本地磁盘:

# Construct a TFRecordDataset
ds_train tf.data.TFRecordDataset('gs://') # path to TFRecords on GCS
ds_train = ds_train.shuffle(1000).batch(32)

# Fit a tf.keras model
model.fit(ds_train)
Run Code Online (Sandbox Code Playgroud)

有关TFRecord 选项的更多信息,请参阅此问题。这也适用于直接从 GCS 上的图像实例化的 Dataset 对象Dataset.from_tensor_slices,这样您就不必先以 TFRecords 格式存储数据:

def load_and_preprocess_image(path):
"""Read an image GCS path and process it into an image tensor

Args:
    path (tensor): string tensor, pointer to GCS or local image path

Returns:
    tensor: processed image tensor
"""

    image = tf.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    return image

image_paths = ['gs://my-bucket/img1.png',
               'gs://my-bucket/img2/png'...]
path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
image_ds = path_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels) # can be a list of labels    
model.fit(tf.data.Dataset.zip((images_ds, labels_ds)))
Run Code Online (Sandbox Code Playgroud)

有关更多示例,请参阅TF 网站上教程

3/ 最后,还应该可以编写自己的 python 生成器或调整源代码,ImageDataGenerator以便使用 TensorFlow io 函数读取图像。同样,这些适用于gs://路径:

import tensorflow as tf
tf.enable_eager_execution()
path = 'gs://path/to/my/image.png'
tf.image.decode_png(tf.io.read_file(path)) # this works
Run Code Online (Sandbox Code Playgroud)

另请参阅此相关问题。这可能比上面列出的选项工作得慢。