如何使用支持生成器的 Model.fit(在 fit_generator 弃用后)

Beh*_*nam 23 python keras tensorflow

Model.fit_generator在 tensorflow 中使用时收到了这个弃用警告:

WARNING:tensorflow: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.
Run Code Online (Sandbox Code Playgroud)

我该如何使用Model.fit而不是Model.fit_generator

小智 25

Model.fit_generator从目前在rc1 中的tensorflow 2.1.0 开始不推荐使用。您可以在此处找到 tf-2.1.0-rc1 的文档:https : //www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/Model#fit

如您所见, can 的第一个参数Model.fit采用生成器,因此只需将生成器传递给它即可。


小智 14

文档中所述(强调我的):

x:输入数据。它可能是

  • 一个 Numpy 数组(或类似数组),或数组列表(以防模型有多个输入)。
  • 一个 TensorFlow 张量,或一个张量列表(如果模型有多个输入)。
  • 如果模型具有命名输入,则 dict 将输入名称映射到相应的数组/张量。
  • tf.data 数据集。应该返回(输入,目标)或(输入,目标,sample_weights)的元组
  • 返回(输入、目标)或(输入、目标、样本权重)的生成器或 keras.utils.Sequence。下面给出了迭代器类型(数据集、生成器、序列)的解包行为的更详细描述。

您可以简单地将生成器传递给Model.fit类似于Model.fit_generator

data_gen_train = ImageDataGenerator(rescale=1/255.)

data_gen_valid = ImageDataGenerator(rescale=1/255.)

train_generator = data_gen_train.flow_from_directory(train_dir, target_size=(128,128), batch_size=128, class_mode="binary")

valid_generator = data_gen_valid.flow_from_directory(validation_dir, target_size=(128,128), batch_size=128, class_mode="binary")

model.fit(train_generator, epochs=2, validation_data=valid_generator) 
Run Code Online (Sandbox Code Playgroud)

  • 你的回答是不正确的,或者,我应该说,只是部分正确。model.fit 不支持 valid_generator (2认同)
  • @sidk是的,似乎“fit”参数“validation_data”不支持数据集、生成器或keras.utils.Sequence,并且当参数“x”是数据集、生成器或keras.utils时不支持“validation_split”。每个“fit”信息的序列位于 https://www.tensorflow.org/api_docs/python/tf/keras/Model (2认同)