将 model.fit_generator 转换为 model.fit

wil*_*007 4 keras tensorflow2.0 tensorflow2.x

我有以下代码,

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
Run Code Online (Sandbox Code Playgroud)

现在model.fit_generator定义如下:

model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)
Run Code Online (Sandbox Code Playgroud)

现在model.fit_generator已过时,什么是改变的正确方法model.fit_generator,以model.fit在这种情况下?

Tim*_*lin 5

您只需要更改model.fit_generator()model.fit().

从 TensorFlow 2.1 开始,model.fit()也接受生成器作为输入。就如此容易。

来自 TensorFlow 的官方文档:

警告:此功能已弃用。它将在未来版本中删除。更新说明:请使用支持生成器的Model.fit。


小智 5

去掉“generator=”。

旧训练:
model.fit_generator(generator=train_generator,
    steps_per_epoch=2048//36, epochs=10,
    validation_data=validation_generator, validation_steps=832//16)
Run Code Online (Sandbox Code Playgroud)
新训练:
model.fit(train_generator,
    steps_per_epoch=2048 // 128, epochs=10,
    validation_data=validation_generator, validation_steps=832//16)
Run Code Online (Sandbox Code Playgroud)

  • 这个答案提供了哪些被接受的答案没有提供的内容? (2认同)
  • @Dragonthoughts 因为它是声明性的,而接受的则是必须的 (2认同)