小编Bau*_*nza的帖子

keras:如何通过model.train_on_batch()使用学习率衰减

在我当前的项目中,我正在使用keras train_on_batch()函数进行训练,因为该fit()函数不支持GAN所需的生成器和鉴别器的交替训练。使用(例如)亚当优化器,我必须在构造函数中指定学习率衰减,optimizer = Adam(decay=my_decay)然后将其递给模型编译方法。如果我以后使用模型的fit()功能,此方法就可以正常工作,因为这需要在内部对训练重复次数进行计数,但是我不知道如何使用类似以下的结构自己设置此值

counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        # train model:
        loss = model.train_on_batch(x, y)    # how to use the current learning rate?
Run Code Online (Sandbox Code Playgroud)

具有一些计算学习率的功能。如何手动设置当前学习率?

抱歉,如果这篇文章中有错误,这是我的第一个问题。

谢谢您的帮助。

python deep-learning keras

6
推荐指数
1
解决办法
1561
查看次数

标签 统计

deep-learning ×1

keras ×1

python ×1