对于无限数据集,每个时期使用的数据是否相同?

aud*_*son 6 python tensorflow

在张量流中,假设我有一个来自生成器的数据集:

dataset = tf.data.Dataset.from_generator(gen...)
Run Code Online (Sandbox Code Playgroud)

这个生成器生成无限的非重复数据(就像无限的非循环小数)。

model.fit(dataset, steps_per_epoch=10000, epochs=5)
Run Code Online (Sandbox Code Playgroud)

现在在这 5 个训练周期内,使用的数据是否相同?即总是来自生成器的前 10000 个项目?而不是 epoch 1 的 0-9999、epoch 2 的 10000-19999 等。

参数呢initial_epoch?如果我设置为1,模型会从第10000项开始训练吗?

model.fit(dataset, steps_per_epoch=10000, epochs=5, initial_epoch=1)
Run Code Online (Sandbox Code Playgroud)

更新: 这个简单的测试表明每次model.fit()调用时数据集都会重置

def gen():
    i = 1
    while True:
        yield np.array([[i]]), np.array([[0]])
        i += 1

ds = tf.data.Dataset.from_generator(gen, output_types=(tf.int32, tf.int32)).batch(3)

x = Input(shape=(1, 1))
model = Model(inputs=x, outputs=x)

model.compile('adam', loss=lambda true, pred: tf.reduce_mean(pred))
for i in range(10):
    model.fit(ds, steps_per_epoch=5, epochs=1)
Run Code Online (Sandbox Code Playgroud)

输出:

1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 9ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 2ms/step - loss: 8.0000
Run Code Online (Sandbox Code Playgroud)

1 次调用 5 个 epoch:

model.fit(ds, steps_per_epoch=5, epochs=5)
Run Code Online (Sandbox Code Playgroud)

输出:

Epoch 1/5
1/5 [=====>........................] - ETA: 0s - loss: 2.0000
5/5 [==============================] - 0s 9ms/step - loss: 8.0000
Epoch 2/5
1/5 [=====>........................] - ETA: 0s - loss: 17.0000
5/5 [==============================] - 0s 2ms/step - loss: 23.0000
Epoch 3/5
1/5 [=====>........................] - ETA: 0s - loss: 32.0000
5/5 [==============================] - 0s 2ms/step - loss: 38.0000
Epoch 4/5
1/5 [=====>........................] - ETA: 0s - loss: 47.0000
5/5 [==============================] - 0s 2ms/step - loss: 53.0000
Epoch 5/5
1/5 [=====>........................] - ETA: 0s - loss: 62.0000
5/5 [==============================] - 0s 2ms/step - loss: 68.0000
Run Code Online (Sandbox Code Playgroud)

Szy*_*zke 2

不,使用的数据不同。steps_per_epoch用于keras确定每个的长度epoch(因为生成器没有长度),因此它知道何时结束训练(或调用检查指针等)。

initial_epoch是为纪元显示的数字,当您想从检查点重新开始训练时有用(请参阅fit 方法),它与数据迭代无关。

如果您将相同的方法传递datasetmodel.fit方法,它将在每次函数调用后重置(感谢信息OP)。