使用张量流训练 CNN 时如何修复“OutOfRangeError:序列结束”错误?

R0b*_*ert 4 python-3.x tensorflow tensorflow-datasets

我正在尝试使用我自己的数据集训练 CNN。我一直在使用 tfrecord 文件和 tf.data.TFRecordDataset API 来处理我的数据集。它适用于我的训练数据集。但是当我尝试对我的验证数据集进行批处理时,出现了“OutOfRangeError: End of sequence”的错误。上网浏览后,我以为是验证集的batch size问题,我一开始设置为32。但是在我将其更改为 2 之后,代码运行了大约 9 个 epoch,并且错误再次出现。

我使用输入函数来处理数据集,代码如下:

def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()

    features, labels = iterator.get_next()

    return features, labels
Run Code Online (Sandbox Code Playgroud)

对于训练集,“batch_size”设置为 128,“num_epochs”设置为 None,这意味着无限重复。对于验证集,“batch_size”设置为 32(后来设置为 2,仍然无效),“num_epochs”设置为 1,因为我只想通过验证集一次。我可以保证验证集包含足够的 epoch 数据。因为我已经尝试了下面的代码并且没有引发任何错误:

with tf.Session() as sess:
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    for i in range(450):
        sess.run([features, labels])
        print(labels.shape)
Run Code Online (Sandbox Code Playgroud)

在上面的代码中,当我将数字 450 更改为 500 或任何更大的值时,它会引发“OutOfRangeError”。这可以确认我的验证数据集包含足够 450 次迭代的数据,批量大小为 32。

我尝试对验证集使用较小的批量大小(即 2),但仍然存在相同的错误。我可以让代码在 input_fn 中将“num_epochs”设置为“None”来运行以进行验证,但这似乎不是验证的工作方式。请问有什么帮助吗?

Oli*_*ene 5

这种行为很正常。从 Tensorflow 文档:

如果迭代器到达数据集的末尾,则执行该Iterator.get_next()操作将引发tf.errors.OutOfRangeError. 此后迭代器将处于不可用状态,如果您想进一步使用它,则必须再次初始化它。

设置的时候不报错的原因dataset.repeat(None)是因为数据集是无限重复的,所以永远不会用完。

要解决您的问题,您应该将代码更改为:

n_steps = 450
...    

with tf.Session() as sess:
    # Training
    features, labels = input_fn(True, training_list, 32, 1, 1)

    for step in range(n_steps):
        sess.run([features, labels])
        ...
    ...
    # Validation
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    try:
        sess.run([features, labels])
        ...
    except tf.errors.OutOfRangeError:
        print("End of dataset")  # ==> "End of dataset"
Run Code Online (Sandbox Code Playgroud)

您还可以对 input_fn 进行一些更改以在每个时期运行评估:

def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_initializable_iterator()
    return iterator

n_epochs = 10
freq_eval = 1

training_iterator = input_fn(True, training_list, 32, 1, 1)
training_features, training_labels = training_iterator.get_next()

val_iterator = input_fn(False, valid_list, 32, 1, 1)
val_features, val_labels = val_iterator.get_next()

with tf.Session() as sess:
    # Training
    sess.run(training_iterator.initializer)
    for epoch in range(n_epochs):
        try:
            sess.run([training_features, training_labels])
        except tf.errors.OutOfRangeError:
            pass

        # Validation
        if (epoch+1) % freq_eval == 0:
            sess.run(val_iterator.initializer)
            try:
                sess.run([val_features, val_labels])
            except tf.errors.OutOfRangeError:
                pass
Run Code Online (Sandbox Code Playgroud)

如果您想更好地了解幕后发生的事情,我建议您仔细查看此官方指南