How do I load a pretrained PyTorch model?

Pen*_*uin 7 python pytorch

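I'm following this guide on saving and loading checkpoints, but something isn't right. My model trains, and its parameters are updated correctly during training. However, something seems to go wrong when I load the checkpoint: after loading, the parameters are no longer updated.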
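For comparison, the general state_dict-based pattern for saving a model together with its optimizer and restoring both can be sketched as below. This is a minimal sketch under assumptions that do not come from the post: a stand-in nn.Linear model, Adam with lr=0.1, and an in-memory io.BytesIO buffer in place of a file path such as 'test.pt'.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in model and optimizer (assumptions, not the classes from the question).
model = nn.Linear(3, 1)
optimizer = optim.Adam(model.parameters(), lr=0.1)

# One training step so the optimizer has internal state (Adam moments) to save.
model(torch.ones(1, 3)).pow(2).sum().backward()
optimizer.step()

# Save: store state_dicts rather than the Parameter objects themselves.
buffer = io.BytesIO()  # in-memory stand-in for a path such as 'test.pt'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, buffer)

# Load: construct fresh objects first, then copy the saved state into them.
restored_model = nn.Linear(3, 1)
restored_optimizer = optim.Adam(restored_model.parameters(), lr=0.1)
buffer.seek(0)
checkpoint = torch.load(buffer)
restored_model.load_state_dict(checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
```

The point of this pattern is that load_state_dict copies values into parameters that already exist, so an optimizer built from those parameters keeps pointing at the tensors the model actually uses.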

My model:

import torch
import torch.nn as nn
import torch.optim as optim

PATH = 'test.pt'

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.a = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.b = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.c = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        #print(self.a, self.b, self.c)

    def load(self):
        try:
            checkpoint = torch.load(PATH)
            print('\nloading pre-trained model...')
            self.a = checkpoint['a']
            self.b = checkpoint['b']
            self.c = checkpoint['c']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(self.a, self.b, self.c)
        except: #file doesn't exist yet
            pass

    @property
    def b_opt(self):
        return torch.tanh(self.b)*2

    def train(self):
        print('training...')
        for epoch in range(3):
            print(self.a, self.b, self.c)
            for r in range(5):
                optimizer.zero_grad()
                loss = torch.square(5 * (r > 2) * (3) - model_net.a * torch.sigmoid((r - model_net.b)) * (model_net.c))
                loss.backward(retain_graph=True) #accumulate gradients

            #checkpoint save
            torch.save({
                'model': model_net.state_dict(),
                'a': model_net.a,
                'b': model_net.b,
                'c': model_net.c,
                'optimizer_state_dict': optimizer.state_dict(),
                }, PATH)

            optimizer.step()


model_net = model()
optimizer = optim.Adam(model_net.parameters(), lr = 0.1)


print(model_net.a)
print(model_net.b)
print(model_net.c)

This prints:

Parameter containing:
tensor([0.4214], requires_grad=True)
Parameter containing:
tensor([0.3862], requires_grad=True)
Parameter containing:
tensor([0.8812], requires_grad=True)

Then I run model_net.train() to check that the parameters are being updated, and it outputs:

training...
Parameter containing:
tensor([0.9990], requires_grad=True) Parameter containing:
tensor([0.1580], requires_grad=True) Parameter containing:
tensor([0.1517], requires_grad=True)
Parameter containing:
tensor([1.0990], requires_grad=True) Parameter containing:
tensor([0.0580], requires_grad=True) Parameter containing:
tensor([0.2517], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
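One detail of the train() loop in the question: optimizer.zero_grad() is called inside the inner for r loop, so despite the #accumulate gradients comment, each iteration discards the previous gradient, and only the gradient of the last loss (r = 4) is left when optimizer.step() runs once per epoch. Each loss is also built from scratch on every iteration, so retain_graph=True should not be needed. The sketch below checks both points using a hypothetical scalar parameter and a toy loss a * r, which are assumptions, not the question's loss.

```python
import torch

def grad_after_loop(zero_inside):
    # Hypothetical scalar parameter standing in for model_net.a.
    a = torch.nn.Parameter(torch.tensor([1.0]))
    optimizer = torch.optim.SGD([a], lr=0.5)
    optimizer.zero_grad()
    for r in range(5):
        if zero_inside:
            optimizer.zero_grad()  # wipes the gradients from earlier iterations
        loss = (a * r).sum()       # a fresh graph each time, so no retain_graph
        loss.backward()            # adds d(loss)/da = r to a.grad
    return a.grad.item()

print(grad_after_loop(zero_inside=True))   # 4.0: only the last loss (r = 4) survives
print(grad_after_loop(zero_inside=False))  # 10.0: 0 + 1 + 2 + 3 + 4
```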

Running model_net.load() outputs:

loading pre-trained model...
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)

Finally, running model_net.train() again outputs:

training...
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
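The frozen values after load() are consistent with a stale-reference problem: self.a = checkpoint['a'] registers a brand-new Parameter on the module, while the module-level optimizer was constructed from the old Parameter objects and keeps updating those. The sketch below reproduces this in isolation; the one-parameter module Tiny and the use of plain SGD are assumptions made for brevity.

```python
import torch
import torch.nn as nn
import torch.optim as optim

class Tiny(nn.Module):
    # Hypothetical one-parameter module mirroring the question's self.a.
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor([1.0]))

net = Tiny()
opt = optim.SGD(net.parameters(), lr=0.5)
old_a = net.a

# Mimic the original load(): bind a *new* Parameter object to the attribute.
net.a = nn.Parameter(torch.tensor([5.0]))

tracked = opt.param_groups[0]['params'][0]
print(tracked is old_a)  # True: the optimizer still holds the old Parameter
print(tracked is net.a)  # False: the new Parameter is unknown to the optimizer

opt.zero_grad()
(net.a * 2).sum().backward()  # the gradient lands on the new net.a ...
opt.step()                    # ... but step() only visits old_a, whose grad is None
print(net.a.item())           # 5.0: unchanged, like the frozen values after load()
```

In the original load(), optimizer.load_state_dict restores the optimizer's internal state but does not rebind it to the new Parameter objects, so it does not undo this.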

Update 1
Following @jhso's suggestion, I changed load to:

def load(self):
  try:
    checkpoint = torch.load(PATH)  
    print('\nloading pre-trained model...')
    self.load_state_dict(checkpoint['model'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(self.a, self.b, self.c)
  except: #file doesn't exist yet
    pass
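The model half of this change avoids the stale-reference issue, because Module.load_state_dict copies the saved values into the existing Parameter objects instead of replacing them. A minimal sketch, again using a hypothetical one-parameter module Tiny and SGD (assumptions, not the question's setup):

```python
import torch
import torch.nn as nn
import torch.optim as optim

class Tiny(nn.Module):
    # Hypothetical one-parameter module mirroring the question's self.a.
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor([1.0]))

net = Tiny()
opt = optim.SGD(net.parameters(), lr=0.5)
param_before = net.a

# load_state_dict copies the saved values into the existing Parameter in place.
net.load_state_dict({'a': torch.tensor([5.0])})
print(net.a is param_before)                      # True: same object as before
print(opt.param_groups[0]['params'][0] is net.a)  # True: the optimizer still tracks it

opt.zero_grad()
(net.a * 2).sum().backward()  # d(loss)/da = 2
opt.step()                    # 5.0 - 0.5 * 2 = 4.0
print(net.a.item())           # 4.0
```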

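This almost seems to work (the network is now training), but I don't think the optimizer is being loaded correctly, because execution never gets past the line self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']). You can tell because print(self.a, self.b, self.c) produces no output when I run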

model_net.load()
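A likely explanation for the missing print: the class never assigns self.optimizer (the optimizer is the module-level variable optimizer), so self.optimizer should raise AttributeError, and the bare except: then silently discards that error along with the rest of load(). The sketch below shows the error on a hypothetical Model class, plus a hypothetical load_checkpoint helper that ignores only a missing file and receives the optimizer as an explicit argument.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    # Hypothetical stand-in: like the class in the question, it never sets self.optimizer.
    def load(self):
        return self.optimizer  # nn.Module.__getattr__ raises AttributeError here

try:
    Model().load()
except AttributeError as err:
    message = str(err)  # keep the text; `err` is unbound after the except block
print(message)  # 'Model' object has no attribute 'optimizer'

def load_checkpoint(model, optimizer, path):
    # Hypothetical helper: only a missing checkpoint file is treated as expected.
    try:
        checkpoint = torch.load(path)
    except FileNotFoundError:
        return False
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return True
```

Narrowing the except clause this way lets any other failure (a missing attribute, a wrong key, a shape mismatch) surface as a traceback instead of being skipped.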