Nag*_*S N 5 python pre-trained-model pytorch
我有两个网络,我需要将它们连接起来以形成我的完整模型。然而,我的第一个模型是预先训练的,我需要在训练完整模型时使其不可训练。我怎样才能在 PyTorch 中实现这一目标。
我可以使用这个答案连接两个模型
class MyModelA(nn.Module):
def __init__(self):
super(MyModelA, self).__init__()
self.fc1 = nn.Linear(10, 2)
def forward(self, x):
x = self.fc1(x)
return x
class MyModelB(nn.Module):
def __init__(self):
super(MyModelB, self).__init__()
self.fc1 = nn.Linear(20, 2)
def forward(self, x):
x = self.fc1(x)
return x
class MyEnsemble(nn.Module):
def __init__(self, modelA, modelB):
super(MyEnsemble, self).__init__()
self.modelA = modelA
self.modelB = modelB
def forward(self, x):
x1 = self.modelA(x)
x2 = self.modelB(x1)
return x2
# Create models and load state_dicts
modelA = MyModelA()
modelB = MyModelB()
# Load state dicts
modelA.load_state_dict(torch.load(PATH))
model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)
Run Code Online (Sandbox Code Playgroud)
基本上,我想modelA
在训练 Ensemble 模型时加载预训练的模型并使其不可训练。
一种简单的方法是更新detach
您不想更新的模型的输出张量,并且它不会将梯度反向传播到连接的模型。在您的情况下,您可以简单地在模型的前向函数中detach x2
连接之前进行张量,以保持权重不变。x1
MyEnsemble
modelB
因此,新的转发函数应如下所示:
def forward(self, x1, x2):
x1 = self.modelA(x1)
x2 = self.modelB(x2)
x = torch.cat((x1, x2.detach()), dim=1) # Detaching x2, so modelB wont be updated
x = self.classifier(F.relu(x))
return x
Run Code Online (Sandbox Code Playgroud)
您可以通过设置为 false 来冻结不想训练的模型的所有参数requires_grad
。像这样:
for param in model.parameters():
param.requires_grad = False
Run Code Online (Sandbox Code Playgroud)
这应该对你有用。
另一种方法是在火车循环中处理这个问题:
modelA = MyModelA()
modelB = MyModelB()
criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)
for epoch in range(epochs):
for samples, targets in dataloader:
optimizerB.zero_grad()
x = modelA.train()(samples)
predictions = modelB.train()(samples)
loss = criterionB(predictions, targets)
loss.backward()
optimizerB.step()
Run Code Online (Sandbox Code Playgroud)
因此,您将 modelA 的输出传递给 modelB,但仅优化 modelB。
归档时间: |
|
查看次数: |
11532 次 |
最近记录: |