pytorch如何设置.requires_grad错误

Qia*_*ang 17 python gradient-descent pytorch

我想把我的一些模型冻结.遵循官方文档:

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.eval()
    print(linear.weight.requires_grad)
Run Code Online (Sandbox Code Playgroud)

但它打印True而不是False.如果我想在eval模式下设置模型,我该怎么办?

iac*_*ppo 25

requires_grad =假

如果要冻结模型的一部分并训练其余部分,可以设置requires_grad要冻结的参数False.

例如,如果您只想保持VGG16的卷积部分是固定的:

model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
    param.requires_grad = False
Run Code Online (Sandbox Code Playgroud)

通过将requires_grad标志切换到False,将不会保存中间缓冲区,直到计算到达某个点,其中操作的一个输入需要梯度.

torch.no_grad()

使用上下文管理器torch.no_grad是实现该目标的另一种方式:在no_grad上下文中requires_grad=False,即使输入具有,计算的所有结果也将具有requires_grad=True.请注意,在之前,您将无法将渐变反向传播到图层no_grad.例如:

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)
x1 = lin0(x)
with torch.no_grad():    
    x2 = lin1(x1)
x3 = lin2(x2)
x3.sum().backward()
print(lin0.weight.grad, lin1.weight.grad, lin2.weight.grad)
Run Code Online (Sandbox Code Playgroud)

输出:

(None, None, tensor([[-1.4481, -1.1789],
         [-1.4481, -1.1789]]))
Run Code Online (Sandbox Code Playgroud)

lin1.weight.requires_grad是真的,但是没有计算梯度,因为oepration是在no_grad上下文中完成的.

model.eval()

如果您的目标不是微调,而是在推理模式下设置模型,最方便的方法是使用torch.no_grad上下文管理器.在这种情况下,你还可以设置你的模型来评估模式,这是通过调用实现eval()nn.Module,例如:

model = torchvision.models.vgg16(pretrained=True)
model.eval()
Run Code Online (Sandbox Code Playgroud)

此操作将self.training图层的属性设置为False,在实践中,这将改变操作的行为,例如DropoutBatchNorm在训练和测试时必须表现不同的操作.


ben*_*che 6

为了完成@Salih_Karagoz的答案,您还拥有上下文(此处torch.set_grad_enabled()有更多文档),它可用于在训练/评估模式之间轻松切换:

linear = nn.Linear(1,1)

is_train = False

for param in linear.parameters():
    param.requires_grad = is_train
with torch.set_grad_enabled(is_train):
    linear.eval()
    print(linear.weight.requires_grad)
Run Code Online (Sandbox Code Playgroud)


Sal*_*goz 5

这是方法;

linear = nn.Linear(1,1)

for param in linear.parameters():
    param.requires_grad = False

with torch.no_grad():
    linear.eval()
    print(linear.weight.requires_grad)
Run Code Online (Sandbox Code Playgroud)

输出:错误