Jac*_*onZ 2 python machine-learning reinforcement-learning deep-learning tensorflow
我正在构建一个普通的 DQN 模型来玩 OpenAI 健身房 Cartpole 游戏。
然而,在训练步骤中,我将状态作为输入,目标 Q 值作为标签,如果我使用model.fit(x=states, y=target_q)
,它工作正常并且代理最终可以很好地玩游戏,但是如果我使用model.train_on_batch(x=states, y=target_q)
,损失将不会减少并且模型不会比随机策略更好地玩游戏。
我想知道fit
和 和 有train_on_batch
什么区别?据我了解,在fit
后台调用train_on_batch
批处理大小为 32 应该没有区别,因为指定批处理大小等于我输入的实际数据大小没有区别。
如果需要更多上下文信息来回答这个问题,完整代码在这里:https : //github.com/ultronify/cartpole-tf
Jak*_*kub 11
model.fit
将训练 1 个或多个 epoch。这意味着它将训练多个批次。model.train_on_batch
,顾名思义,只训练一批。
举一个具体的例子,假设你正在用 10 张图像训练一个模型。假设您的批量大小为 2。model.fit
将在所有 10 张图像上进行训练,因此它将更新梯度 5 次。(您可以指定多个时期,因此它会遍历您的数据集。)model.train_on_batch
将执行梯度的一次更新,因为您只批量提供模型。model.train_on_batch
如果您的批量大小为 2,您将提供两张图片。
如果我们假设在幕后model.fit
调用model.train_on_batch
(尽管我不认为它确实如此),那么model.train_on_batch
将被多次调用,可能是在循环中。这里用伪代码来解释。
def fit(x, y, batch_size, epochs=1):
for epoch in range(epochs):
for batch_x, batch_y in batch(x, y, batch_size):
model.train_on_batch(batch_x, batch_y)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
2791 次 |
最近记录: |