如何使用 ImageDataGenerator 将文件夹拆分为 3 个数据集?

zs2*_*020 8 python keras imagedatagenerator

validation_split 参数能够允许 ImageDataGenerator 将从文件夹读取的数据集拆分为 2 个不同的不相交集。有没有办法使用它创建 3 组训练、验证和评估数据集?

我正在考虑将数据集拆分为 2 个数据集,然后将第二个数据集拆分为另外 2 个数据集

datagen = ImageDataGenerator(validation_split=0.5, rescale=1./255)

train_generator = datagen.flow_from_directory(
    TRAIN_DIR, 
    subset='training'
)

val_generator = datagen.flow_from_directory(
    TRAIN_DIR,
    subset='validation'
)
Run Code Online (Sandbox Code Playgroud)

在这里,我正在考虑使用 val_generator 将验证数据集分成 2 组。一个用于验证,另一个用于评估?我该怎么做呢?

Gio*_*oni 2

我喜欢使用 的flow_from_dataframe()方法ImageDataGenerator,在该方法中我与一个简单的 Pandas DataFrame(可能包含其他功能)交互,而不是与目录交互。但如果您坚持的话,您可以轻松更改我的代码flow_from_directory()

所以这是我的首选函数,例如对于回归任务,我们尝试预测连续的y

def get_generators(train_samp, test_samp, validation_split = 0.1):
    train_datagen = ImageDataGenerator(validation_split=validation_split, rescale = 1. / 255)
    test_datagen = ImageDataGenerator(rescale = 1. / 255)
    
    train_generator = train_datagen.flow_from_dataframe(
        dataframe = images_df[images_df.index.isin(train_samp)],
        directory = images_dir,
        x_col = 'img_file',
        y_col = 'y',
        target_size = (IMG_HEIGHT, IMG_WIDTH),
        class_mode = 'raw',
        batch_size = batch_size,
        shuffle = True,
        subset = 'training',
        validate_filenames = False
    )
    valid_generator = train_datagen.flow_from_dataframe(
        dataframe = images_df[images_df.index.isin(train_samp)],
        directory = images_dir,
        x_col = 'img_file',
        y_col = 'y',
        target_size = (IMG_HEIGHT, IMG_WIDTH),
        class_mode = 'raw',
        batch_size = batch_size,
        shuffle = False,
        subset = 'validation',
        validate_filenames = False
    )

    test_generator = test_datagen.flow_from_dataframe(
        dataframe = images_df[images_df.index.isin(test_samp)],
        directory = images_dir,
        x_col = 'img_file',
        y_col = 'y',
        target_size = (IMG_HEIGHT, IMG_WIDTH),
        class_mode = 'raw',
        batch_size = batch_size,
        shuffle = False,
        validate_filenames = False
    )
    return train_generator, valid_generator, test_generator
Run Code Online (Sandbox Code Playgroud)

注意事项:

  • 我使用两台发电机
  • 函数的输入是训练/测试索引(例如从 Sklearn 接收的train_test_split),用于过滤 DataFrame 索引。
  • 该函数还validation_split为训练生成器提供一个参数
  • images_df是全局内存中某处的 DataFrame,具有适当的列,例如img_filey
  • 无需shuffle验证和测试生成器

这可以进一步推广到多个输出、分类等。