学习率调度器和优化器之间有什么关系?

Pen*_*uin 7 python pytorch

如果我有一个模型:

import torch
import torch.nn as nn
import torch.optim as optim

class net_x(nn.Module): 
        def __init__(self):
            super(net_x, self).__init__()
            self.fc1=nn.Linear(2, 20) 
            self.fc2=nn.Linear(20, 20)
            self.out=nn.Linear(20, 4) 

        def forward(self, x):
            x=self.fc1(x)
            x=self.fc2(x)
            x=self.out(x)
            return x

nx = net_x()
Run Code Online (Sandbox Code Playgroud)

然后我定义我的输入、优化器(使用lr=0.1)、调度程序(使用base_lr=1e-3)和训练:

r = torch.tensor([1.0,2.0])
optimizer = optim.Adam(nx.parameters(), lr = 0.1)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-3, max_lr=0.1, step_size_up=1, mode="triangular2", cycle_momentum=False)

path = 'opt.pt'
for epoch in range(10):
    optimizer.zero_grad()
    net_predictions = nx(r)
    loss = torch.sum(torch.randint(0,10,(4,)) - net_predictions)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print('loss:' , loss)
    
    #save state dict
    torch.save({    'epoch': epoch,
                    'net_x_state_dict': nx.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),    
                    }, path)
#loading state dict
checkpoint = torch.load(path)        
nx.load_state_dict(checkpoint['net_x_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler'])
Run Code Online (Sandbox Code Playgroud)

优化器似乎采用了调度器的学习率

for g in optimizer.param_groups:
    print(g)
>>>
{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'initial_lr': 0.001, 'params': [Parameter containing:
Run Code Online (Sandbox Code Playgroud)

学习率调度器会覆盖优化器吗?它如何连接到它?试图理解它们之间的关系(即它们如何相互作用等)

Shi*_*hir 10

TL;DR: LR 调度器包含优化器作为成员,并显式更改其参数学习率。


正如PyTorch 官方文档中提到的中提到的,学习率调度器在其构造函数中接收优化器作为参数,因此可以访问其参数。

常见的用途是在每个 epoch 之后更新 LR:

scheduler = ... # initialize some LR scheduler
for epoch in range(100):
    train(...) # here optimizer.step() is called numerous times.
    validate(...)
    scheduler.step()
Run Code Online (Sandbox Code Playgroud)

所有优化器都继承自一个公共父类torch.nn.Optimizer,并使用step,并使用为每个优化器实现的方法

类似地,所有 LR 调度程序(除了ReduceLROnPlateau)都继承自名为 的公共父类_LRScheduler。观察其源代码发现,step该类在方法中确实改变了优化器参数的LR:

...
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr
...
Run Code Online (Sandbox Code Playgroud)

  • 如果你想覆盖 LR 调度器的学习率,你可以覆盖它的 _last_lr 成员,下次你使用调度器采取步骤时,它将根据这个新的 LR 进行更新。但是,这不是一个好的做法,因为您覆盖了私有成员。考虑使用“MultiStepLR”,它使您能够在每个调度程序的步骤中手动选择 LR。 (3认同)