使用 ImageDataGenerator 训练的 Keras 模型进行预测的正确方法

Sre*_* TP 5 machine-learning keras

我通过在 Keras 中使用 ImageDataGenerator 训练了一个应用一些图像增强的模型,如下所示:

train_datagen = ImageDataGenerator(
        rotation_range=60, 
        width_shift_range=0.1,
        height_shift_range=0.1, 
        horizontal_flip=True)
train_datagen.fit(x_train)

history = model.fit_generator(
    train_datagen.flow(x_train, y_train, batch_size=7),
    steps_per_epoch=600,
    epochs=epochs,
    callbacks=callbacks_list
)
Run Code Online (Sandbox Code Playgroud)

我应该如何使用此模型进行预测?通过使用model.predict()如下所示?

predictions = model.predict(x_test)
Run Code Online (Sandbox Code Playgroud)

或者我应该在未标记的model.predict_generator()地方使用 ImageDataGenerator 的x_test地方x_test

如果我使用predict_generator():怎么做?

两种方法有什么区别?

pet*_*ich 4

predict_generator()是一个方便的函数,可以更轻松地加载图像并应用与训练样本相同的预处理。我建议使用它而不是model.predict.

在你的情况下,只需执行以下操作:

test_gen = ImageDataGenerator()
predictions = model.predict_generator(test_gen.flow(# ... your params here ... #))
Run Code Online (Sandbox Code Playgroud)