TF Keras API with TF Dataset: problem with the steps_per_epoch argument

Jos*_*bel 6 python deep-learning tensorflow tensorflow-datasets

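When I try to fit a Keras model that is written with the tensorflow.keras API and fed by iterators derived from a tf.Dataset, the model complains about the steps_per_epoch argument, even though I have set this argument to a concrete value.

Here is my model class: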

import tensorflow as tf
import numpy as np
from typing import Union, List
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras import layers
from tftools import TFTools


class TestServe():
    def __init__(self, tfrecords: Union[List[tf.train.Example], tf.train.Example], batch_size: int = 10, input_shape: tuple = (64, 23)) -> None:
        self.tfrecords = tfrecords
        self.batch_size = batch_size
        self.input_shape = input_shape

    def get_model(self):
        ins = layers.Input(shape=(64, 23))

        l = layers.Reshape((*self.input_shape, 1))(ins)
        l = layers.Conv2D(8, (30, 23), padding='same', activation='relu')(l)
        l = layers.MaxPool2D((4, 5), strides=(4, 5))(l)
        l = layers.Conv2D(16, (3, 3), padding='same', activation='relu')(l)
        l = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(l)
        l = layers.MaxPool2D((2, 2), strides=(2, 2))(l)
        l = layers.Flatten()(l)

        # sigmoid pairs with binary_crossentropy; softmax over a single unit always outputs 1.0
        out = layers.Dense(1, activation='sigmoid')(l)
        return tf.keras.models.Model(ins, out)

    def train(self):

        # Create Dataset
        dataset = TFTools.create_dataset(self.tfrecords)
        dataset = dataset.repeat(6).batch(self.batch_size)

        val_iterator = dataset.take(300).make_one_shot_iterator()
        train_iterator = dataset.skip(300).make_one_shot_iterator()

        model = self.get_model()
        model.summary()
        model.compile(optimizer='rmsprop',
                      loss='binary_crossentropy', metrics=['accuracy'])
        model.fit(train_iterator, validation_data=val_iterator,
                  epochs=10, verbose=1, steps_per_epoch=20)

    def predict(self, X: np.array) -> np.array:
        pass

ts = TestServe(['./ok.tfrecord', './nok.tfrecord'])
ts.train()

But when I start training, before the first epoch finishes (at step 19/20 in the log), I get an exception from TensorFlow:

2019-06-13 14:22:25.393398: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1995445000 Hz
2019-06-13 14:22:25.393681: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x2f7d120 executing computations on platform Host. Devices:
2019-06-13 14:22:25.393708: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
Epoch 1/2
19/20 [===========================>..] - ETA: 0s - loss: 1.1921e-07 - acc: 1.0000Traceback (most recent call last):
  File "TestServe.py", line 62, in <module>
    ts.train()
  File "TestServe.py", line 56, in train
    epochs=2, verbose=1, callbacks=callbacks, steps_per_epoch=20) #The steps_per_epoch is typically samples_per_epoch / batch_size
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 880, in fit
    validation_steps=validation_steps)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 364, in model_iteration
    validation_in_fit=True)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 202, in model_iteration
    steps_per_epoch)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 76, in _get_num_samples_or_steps
    'steps_per_epoch')
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 230, in check_num_samples
    if check_steps_argument(ins, steps, steps_name):
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 960, in check_steps_argument
    input_type=input_type_str, steps_name=steps_name))
ValueError: When using data tensors as input to a model, you should specify the `steps_per_epoch` argument.

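The original dataset contains about 1,500 samples, but I want to combine multiple tfrecord files into one TFRecordDataset, so I have no information about its length.

Has anyone seen something similar? I don't know where to look for help, since the tf.keras API is relatively new. The create_dataset function just returns a dataset mapped with the correct parsing function.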

Jos*_*bel 10

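Found the solution.

It is not enough to set steps_per_epoch; when the validation data is also passed as an iterator, you have to specify the validation_steps argument as well.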
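
The traceback is consistent with this: the failing call is made with validation_in_fit=True, right after the last training step of the first epoch, so the check that fails is the one for the validation pass. Below is a minimal sketch of a fit call with both arguments set. It assumes TensorFlow 2.x, where a tf.data.Dataset is passed to fit directly instead of a one-shot iterator. The make_dataset helper, the random data, the tiny model and the step counts are all made up for illustration and are not the code from the question.

```python
import numpy as np
import tensorflow as tf


def make_dataset(num_samples, batch_size, seed):
    """Hypothetical stand-in for TFTools.create_dataset: random (x, y) pairs."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(num_samples, 64, 23)).astype("float32")
    y = rng.integers(0, 2, size=(num_samples, 1)).astype("float32")
    # repeat() with no count makes the dataset infinite, so Keras cannot
    # infer how many batches form one pass over the data.
    return tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(batch_size)


train_ds = make_dataset(num_samples=200, batch_size=10, seed=0)
val_ds = make_dataset(num_samples=50, batch_size=10, seed=1)

# Deliberately small model; only the fit() arguments matter here.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(64, 23)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="rmsprop", loss="binary_crossentropy",
              metrics=["accuracy"])

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=2,
    steps_per_epoch=20,   # training batches drawn per epoch
    validation_steps=5,   # validation batches drawn after each epoch
    verbose=0,
)
```

If the number of samples is known, a common choice is to set each step count to roughly the number of samples divided by the batch size, as the comment in the traceback suggests for steps_per_epoch.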