I am following the guide below on saving and loading checkpoints. However, something is wrong: my model trains and its parameters update correctly during the training phase, but after loading a checkpoint the parameters no longer update.
My model:
import torch
import torch.nn as nn
import torch.optim as optim

PATH = 'test.pt'

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.a = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.b = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.c = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        #print(self.a, self.b, self.c)

    def load(self):
        try:
            checkpoint = torch.load(PATH)
            print('\nloading pre-trained model...')
            self.a = checkpoint['a']
            self.b = checkpoint['b']
            self.c = checkpoint['c']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(self.a, self.b, self.c)
        except: #file doesn't exist yet
            pass

    @property
    def b_opt(self):
        return torch.tanh(self.b)*2

    def train(self):
        print('training...')
        for epoch in range(3):
            print(self.a, self.b, self.c)
            for r in range(5):
                optimizer.zero_grad()
                loss = torch.square(5 * (r > 2) * (3) - model_net.a * torch.sigmoid((r - model_net.b)) * (model_net.c))
                loss.backward(retain_graph=True) #accumulate gradients
                #checkpoint save
                torch.save({
                    'model': model_net.state_dict(),
                    'a': model_net.a,
                    'b': model_net.b,
                    'c': model_net.c,
                    'optimizer_state_dict': optimizer.state_dict(),
                }, PATH)
                optimizer.step()

model_net = model()
optimizer = optim.Adam(model_net.parameters(), lr = 0.1)
print(model_net.a)
print(model_net.b)
print(model_net.c)
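For context, the checkpoint pattern I was trying to follow restores parameters with load_state_dict rather than reassigning nn.Parameter attributes (my understanding is that reassignment creates new tensors the optimizer no longer tracks). A minimal sketch of that pattern, with illustrative names:

```python
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(2, 1)                      # toy module with learnable weights
opt = optim.Adam(net.parameters(), lr=0.1)

# one update so the optimizer has per-parameter state worth saving
loss = net(torch.randn(4, 2)).pow(2).mean()
loss.backward()
opt.step()

# save both state dicts in one checkpoint file
torch.save({'model': net.state_dict(),
            'optimizer': opt.state_dict()}, 'ckpt.pt')

# restore: load_state_dict copies values *into* the existing parameter
# tensors, so the optimizer keeps tracking the same objects and later
# opt.step() calls still update them
ckpt = torch.load('ckpt.pt')
net.load_state_dict(ckpt['model'])
opt.load_state_dict(ckpt['optimizer'])
```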
This prints:
Parameter containing:
tensor([0.4214], requires_grad=True)
Parameter containing:
tensor([0.3862], requires_grad=True)
Parameter containing:
tensor([0.8812], requires_grad=True)
Then I run model_net.train() to check that the parameters are updating, and it outputs:
training...
Parameter containing:
tensor([0.9990], requires_grad=True) Parameter containing:
tensor([0.1580], requires_grad=True) Parameter containing:
tensor([0.1517], requires_grad=True)
Parameter containing:
tensor([1.0990], requires_grad=True) Parameter containing:
tensor([0.0580], requires_grad=True) Parameter containing:
tensor([0.2517], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Running model_net.load() outputs:
loading pre-trained model...
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Finally, running model_net.train() again outputs:
training...
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Update 1.

Following @jhso's suggestion, I changed load to:
def load(self):
    try:
        checkpoint = torch.load(PATH)
        print('\nloading pre-trained model...')
        self.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(self.a, self.b, self.c)
    except: #file doesn't exist yet
        pass
This almost seems to work (the network now trains), but I don't think the optimizer is loading correctly, because execution never gets past the line self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']). You can tell because print(self.a, self.b, self.c) produces no output when I run
model_net.load()
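My guess at why the print is skipped (an assumption, not verified): the optimizer was created as a module-level variable named optimizer, so the model instance has no self.optimizer attribute, and that line raises AttributeError. The bare except: pass, which was only meant to cover a missing checkpoint file, silently swallows it. A small self-contained sketch of this masking effect:

```python
class Net:
    """Stand-in for the model; deliberately has no .optimizer attribute."""
    def load(self):
        try:
            # stand-in for self.optimizer.load_state_dict(...): raises
            # AttributeError because the attribute was never set
            self.optimizer.load_state_dict({})
            print('loaded')          # never reached
        except:                      # intended for "file doesn't exist yet"...
            pass                     # ...but it hides the AttributeError too

net = Net()
net.load()                           # no output and no error: failure is invisible

# accessing the attribute directly makes the real problem surface
try:
    net.optimizer
except AttributeError as e:
    print(type(e).__name__)          # -> AttributeError
```

Catching only the expected exception (except FileNotFoundError:) instead of a bare except would have let the real error surface immediately.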