Keras:使用flow_from_directory为训练数据拟合图像增强

Mar*_*ldt 7 python machine-learning deep-learning keras

我想在Keras中使用Image augmentation.我当前的代码如下所示:

# define image augmentations
train_datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
zca_whitening=True)

# generate image batches from directory
train_datagen.flow_from_directory(train_dir)
Run Code Online (Sandbox Code Playgroud)

当我用这个运行模型时,我收到以下错误:

"ImageDataGenerator specifies `featurewise_std_normalization`, but it hasn't been fit on any training data."
Run Code Online (Sandbox Code Playgroud)

但我没有找到有关如何使用明确的信息train_dataget.fit()一起flow_from_directory.

谢谢您的帮助.马里奥

des*_*aut 14

你是对的,文档在这方面不是很有启发性......

你需要的实际上是一个4步骤的过程:

  1. 定义数据扩充
  2. 适合增强
  3. 使用设置您的发电机 flow_from_directory()
  4. 训练你的模型 fit_generator()

以下是假设图像分类案例的必要代码:

# define data augmentation configuration
train_datagen = ImageDataGenerator(featurewise_center=True,
                                   featurewise_std_normalization=True,
                                   zca_whitening=True)

# fit the data augmentation
train_datagen.fit(x_train)

# setup generator
train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode='categorical')

# train model
model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples,
    epochs=epochs,
    validation_data=validation_generator, # optional - if used needs to be defined
    validation_steps=nb_validation_samples) 
Run Code Online (Sandbox Code Playgroud)

显然,有几个参数来定义(train_data_dir,nb_train_samples等等),但希望你的想法.

如果您还需要使用a validation_generator,如我的示例所示,这应该与您的定义相同train_generator.

更新(评论后)

第2步需要一些讨论; 这里x_train是理想情况下应该适合主存储器的实际数据.另外(文档),这一步是

仅在featurewise_center或featurewise_std_normalization或zca_whitening时才需要.

然而,在许多现实世界中,所有训练数据都适合记忆的要求显然是不现实的.在这种情况下,如何集中/规范化/白化数据本身就是一个(巨大的)子领域,可以说是存在大数据处理框架(如Spark)的主要原因.

那么,在这里做什么呢?那么,在这种情况下,下一个合乎逻辑的行动是对数据进行采样 ; 事实上,这正是社区所建议的 - 这里是Keras创建者Francois Chollet关于使用像Imagenet这样的大型数据集:

datagen.fit(X_sample) # let's say X_sample is a small-ish but statistically representative sample of your data
Run Code Online (Sandbox Code Playgroud)

还有一个关于扩展的持续公开讨论的另一个引用ImageDataGenerator(强调增加):

功能是标准化和ZCA所必需的,它只需要一个数组作为参数,不适合目录.目前,我们需要手动读取图像的子集以使其适合目录.一个想法是我们可以改变fit()以接受生成器本身(flow_from_directory),当然,应该在适合期间禁用标准化.

希望这可以帮助...