迭代器重置时的tensorflow数据集混洗行为

Łuk*_*mek 3 tensorflow tensorflow-datasets

我发现reshuffle_each_iteration参数tf.Dataset.shuffle有点混乱。考虑以下代码:

import tensorflow as tf

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]

filenames = tf.constant(flist)

train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)

it_train_x = train_x_dataset.make_initializable_iterator()

next_sample = it_train_x.get_next()

with tf.Session() as sess:
    for epoch in range(3):
        sess.run(it_train_x.initializer)
        print("Starting epoch ", epoch)
        while True:
            try:
                s = sess.run(next_sample)
                print("Sample: ", s)
            except tf.errors.OutOfRangeError:
                break
Run Code Online (Sandbox Code Playgroud)

代码输出:

Starting epoch  0
Sample:  b'trimg2'
Sample:  b'trimg1'
Sample:  b'trimg3'
Sample:  b'trimg4'
Starting epoch  1
Sample:  b'trimg4'
Sample:  b'trimg3'
Sample:  b'trimg2'
Sample:  b'trimg1'
Starting epoch  2
Sample:  b'trimg3'
Sample:  b'trimg2'
Sample:  b'trimg4'
Sample:  b'trimg1'
Run Code Online (Sandbox Code Playgroud)

即使reshuffle_each_iterationFalse,tensorflow在迭代完数据集后仍会重新组合数据。还有另一种重置迭代器的方法吗?预期的行为是reshuffle_each_iteration什么?

我知道我可以seed每次修复并获得相同的订单,问题是关于reshuffle_each_iteration应该如何工作。

我也知道,使用时间段是更惯用的方式repeat(),但是在我的情况下,每个时间段的实际采样数会有所不同。

KRi*_*ish 6

我怀疑TensorFlow仍会在for循环的每次迭代中重新组合数据集,因为迭代器是在每次迭代时初始化的。每次初始化迭代器时,都会将shuffle函数应用于数据集。

预期的行为是迭代器被初始化一次,并reshuffle_each_iteration允许您选择是否在数据重复时重新洗牌(每次原始数据都经过迭代)。

我不确定如何重新格式化您的代码以处理可变数量的样本,但这是使用repeat()功能修改的代码,以证明我的主张:

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]

filenames = tf.constant(flist)

train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)
train_x_dataset = train_x_dataset.repeat(4)

it_train_x = train_x_dataset.make_initializable_iterator()

next_sample = it_train_x.get_next()

with tf.Session() as sess:
    sess.run(it_train_x.initializer)
    while True:
        try:
            s = sess.run(next_sample)
            print("Sample: ", s)
        except tf.errors.OutOfRangeError:
            break
Run Code Online (Sandbox Code Playgroud)

输出:

Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Run Code Online (Sandbox Code Playgroud)

而如果我设置reshuffle_each_iteration=True,我将得到:

Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg2
Sample:  trimg1
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg4
Sample:  trimg1
Sample:  trimg2
Sample:  trimg3
Run Code Online (Sandbox Code Playgroud)

希望这可以帮助!

编辑:我的主张的进一步证据:TensorFlow代码库中的这两个测试功能。在这种情况下,将使用单发迭代器,因此仅初始化一次。批量大小为10的数据用于大小为10的数据,因此每次调用iterator.get_next()都遍历整个源数据。该代码检查该函数的每个后续调用是否返回相同(改组)的数组。

关于此问题的讨论可以进一步阐明不同迭代器的预期用途和预期行为,并可能帮助您找到解决特定问题的解决方案。