TensorFlow 模型拟合和 train_on_batch 之间的区别

Jac*_*onZ 2 python machine-learning reinforcement-learning deep-learning tensorflow

我正在构建一个普通的 DQN 模型来玩 OpenAI 健身房 Cartpole 游戏。

然而,在训练步骤中,我将状态作为输入,目标 Q 值作为标签,如果我使用model.fit(x=states, y=target_q),它工作正常并且代理最终可以很好地玩游戏,但是如果我使用model.train_on_batch(x=states, y=target_q),损失将不会减少并且模型不会比随机策略更好地玩游戏。

我想知道fit和 和 有train_on_batch什么区别?据我了解,在fit后台调用train_on_batch批处理大小为 32 应该没有区别,因为指定批处理大小等于我输入的实际数据大小没有区别。

如果需要更多上下文信息来回答这个问题,完整代码在这里:https : //github.com/ultronify/cartpole-tf

Jak*_*kub 11

model.fit将训练 1 个或多个 epoch。这意味着它将训练多个批次。model.train_on_batch,顾名思义,只训练一批。

举一个具体的例子,假设你正在用 10 张图像训练一个模型。假设您的批量大小为 2。model.fit将在所有 10 张图像上进行训练,因此它将更新梯度 5 次。(您可以指定多个时期,因此它会遍历您的数据集。)model.train_on_batch将执行梯度的一次更新,因为您只批量提供模型。model.train_on_batch如果您的批量大小为 2,您将提供两张图片。

如果我们假设在幕后model.fit调用model.train_on_batch(尽管我不认为它确实如此),那么model.train_on_batch将被多次调用,可能是在循环中。这里用伪代码来解释。

def fit(x, y, batch_size, epochs=1):
    for epoch in range(epochs):
        for batch_x, batch_y in batch(x, y, batch_size):
            model.train_on_batch(batch_x, batch_y)
Run Code Online (Sandbox Code Playgroud)