wil*_*007 4 keras tensorflow2.0 tensorflow2.x
我有以下代码,
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
'data/train',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
'data/validation',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
Run Code Online (Sandbox Code Playgroud)
现在model.fit_generator定义如下:
model.fit_generator(
train_generator,
steps_per_epoch=2000,
epochs=50,
validation_data=validation_generator,
validation_steps=800)
Run Code Online (Sandbox Code Playgroud)
现在model.fit_generator已过时,什么是改变的正确方法model.fit_generator,以model.fit在这种情况下?
您只需要更改model.fit_generator()为model.fit().
从 TensorFlow 2.1 开始,model.fit()也接受生成器作为输入。就如此容易。
来自 TensorFlow 的官方文档:
警告:此功能已弃用。它将在未来版本中删除。更新说明:请使用支持生成器的Model.fit。
小智 5
去掉“generator=”。
model.fit_generator(generator=train_generator,
steps_per_epoch=2048//36, epochs=10,
validation_data=validation_generator, validation_steps=832//16)
Run Code Online (Sandbox Code Playgroud)
model.fit(train_generator,
steps_per_epoch=2048 // 128, epochs=10,
validation_data=validation_generator, validation_steps=832//16)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3247 次 |
| 最近记录: |