Mar*_*ldt 7 python machine-learning deep-learning keras
我想在Keras中使用Image augmentation.我当前的代码如下所示:
# define image augmentations
train_datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
zca_whitening=True)
# generate image batches from directory
train_datagen.flow_from_directory(train_dir)
Run Code Online (Sandbox Code Playgroud)
当我用这个运行模型时,我收到以下错误:
"ImageDataGenerator specifies `featurewise_std_normalization`, but it hasn't been fit on any training data."
Run Code Online (Sandbox Code Playgroud)
但我没有找到有关如何使用明确的信息train_dataget.fit()一起flow_from_directory.
谢谢您的帮助.马里奥
des*_*aut 14
你是对的,文档在这方面不是很有启发性......
你需要的实际上是一个4步骤的过程:
flow_from_directory()fit_generator()以下是假设图像分类案例的必要代码:
# define data augmentation configuration
train_datagen = ImageDataGenerator(featurewise_center=True,
featurewise_std_normalization=True,
zca_whitening=True)
# fit the data augmentation
train_datagen.fit(x_train)
# setup generator
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(img_height, img_width),
batch_size=batch_size,
class_mode='categorical')
# train model
model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples,
epochs=epochs,
validation_data=validation_generator, # optional - if used needs to be defined
validation_steps=nb_validation_samples)
Run Code Online (Sandbox Code Playgroud)
显然,有几个参数来定义(train_data_dir,nb_train_samples等等),但希望你的想法.
如果您还需要使用a validation_generator,如我的示例所示,这应该与您的定义相同train_generator.
更新(评论后)
第2步需要一些讨论; 这里x_train是理想情况下应该适合主存储器的实际数据.另外(文档),这一步是
仅在featurewise_center或featurewise_std_normalization或zca_whitening时才需要.
然而,在许多现实世界中,所有训练数据都适合记忆的要求显然是不现实的.在这种情况下,如何集中/规范化/白化数据本身就是一个(巨大的)子领域,可以说是存在大数据处理框架(如Spark)的主要原因.
那么,在这里做什么呢?那么,在这种情况下,下一个合乎逻辑的行动是对数据进行采样 ; 事实上,这正是社区所建议的 - 这里是Keras创建者Francois Chollet关于使用像Imagenet这样的大型数据集:
Run Code Online (Sandbox Code Playgroud)datagen.fit(X_sample) # let's say X_sample is a small-ish but statistically representative sample of your data
还有一个关于扩展的持续公开讨论的另一个引用ImageDataGenerator(强调增加):
功能是标准化和ZCA所必需的,它只需要一个数组作为参数,不适合目录.目前,我们需要手动读取图像的子集以使其适合目录.一个想法是我们可以改变
fit()以接受生成器本身(flow_from_directory),当然,应该在适合期间禁用标准化.
希望这可以帮助...