Keras model.fit() with the tf.data Dataset API + validation_data

Mar*_*ail 12 python keras tensorflow

I got my Keras model working with a tf.Dataset through the following code:

# Initialize batch generators(returns tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)

# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()

# Create Model
logits = .....(some layers)
model = keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_epoch=num_batches, callbacks=callbacks, verbose=1)

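But when I try to pass the validation_data parameter to model.fit, it tells me that I cannot use it with a generator. Is there a way to use validation data while using a tf.Dataset?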

For example, in plain TensorFlow I could do the following:

# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)

# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
                                           batch_train.output_shapes)

# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)

Then I would simply call sess.run(init_op_train) or sess.run(init_op_valid) to switch between the datasets.

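I tried implementing a callback that does just that (switch to the validation set, predict, and return), but it tells me that I cannot use model.predict inside a callback.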

Can someone help me get validation working with Keras + tf.Dataset?

Edit: incorporating the answer into the code

Here is what finally worked for me, thanks to the accepted answer:

# Initialize batch generators(returns tf.Dataset)
batch_train = ...  # returns tf.Dataset
batch_valid = ...  # returns tf.Dataset

# Create TensorFlow Iterator object and wrap it in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_valid)

# Create Model
logits = # the keras model
model = keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit_generator(
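    # NOTE: validation_steps counts validation batches per epoch, not samples per batch.
    # batch_size is kept from the original post; a validation batch count is probably intended.
    generator=itr_train, validation_data=itr_valid, validation_steps=batch_size,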
    epochs=epochs, steps_per_epoch=num_batches, callbacks=cbs, verbose=1, workers=0)

def make_iterator(dataset):
    iterator = dataset.make_one_shot_iterator()
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        while True:
            *inputs, labels = sess.run(next_val)
            yield inputs, labels

This does not introduce any overhead.
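A side note on the step counts used above: in fit_generator, steps_per_epoch and validation_steps are both numbers of batches drawn from the respective generator. A minimal sketch of how they are usually derived, assuming you know the sample counts up front (the sizes and the helper name steps_for are hypothetical, chosen only for illustration):

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to see every sample once (the last batch may be partial)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return math.ceil(num_samples / batch_size)


# Hypothetical dataset sizes, not taken from the original post.
num_train_samples = 50_000
num_valid_samples = 10_000
batch_size = 128

num_batches = steps_for(num_train_samples, batch_size)        # candidate for steps_per_epoch
num_valid_batches = steps_for(num_valid_samples, batch_size)  # candidate for validation_steps
```

If the dataset pipeline drops the final partial batch, floor division (num_samples // batch_size) would be the matching count instead.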

W. *_*am 4

I solved this problem by using fit_generator. I found the solution here. I applied @Dat-Nguyen's solution.

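You just create two iterators, one for training and one for validation, and then write your own generator that pulls batches from the dataset and yields the data in the form (batch_data, batch_labels). Finally, you pass train_generator and validation_generator to model.fit_generator.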
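The generator contract described here can be sketched without any framework. This is only an illustration under assumptions: plain Python lists stand in for the two datasets, and batch_generator is a hypothetical helper name. In the real setup the inner loop would pull each batch from the dataset iterator (for example via sess.run on iterator.get_next(), as in the question's make_iterator).

```python
def batch_generator(batches):
    """Yield (batch_data, batch_labels) tuples forever, restarting when the source runs out.

    Keras-style generators are expected to loop indefinitely; the number of
    batches consumed per epoch is controlled by the step arguments instead.
    """
    while True:
        for batch_data, batch_labels in batches:
            yield batch_data, batch_labels


# Stand-in "datasets": each element is one (batch_data, batch_labels) pair.
train_batches = [([1.0, 2.0], [0, 1]), ([3.0, 4.0], [1, 0])]
valid_batches = [([5.0, 6.0], [1, 1])]

train_generator = batch_generator(train_batches)
validation_generator = batch_generator(valid_batches)

first = next(train_generator)
second = next(train_generator)
third = next(train_generator)  # the source is exhausted here, so the generator wraps around
valid_first = next(validation_generator)
```

The two generators would then go to model.fit_generator as generator=train_generator and validation_data=validation_generator, together with steps_per_epoch and validation_steps.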