Val*_*ier 2 deep-learning pytorch autograd
在 Pytorch 训练中,我使用复合损失函数,定义为:
。为了更新权重 alpha 和 beta,我需要计算三个值:
它们是网络中所有权重的损失项梯度的平均值。
有没有一种有效的方法可以在 pytorch 中编写它?
我的训练代码如下所示:
for epoc in range(1, nb_epochs+1):
#init
optimizer_fo.zero_grad()
#get the current loss
loss_total = mynet_fo.loss(tensor_xy_dirichlet,g_boundaries_d,tensor_xy_inside,tensor_f_inter,tensor_xy_neuman,g_boundaries_n)
#compute gradients
loss_total.backward(retain_graph=True)
#optimize
optimizer_fo.step()
Run Code Online (Sandbox Code Playgroud)
我的 .loss() 函数直接返回各项之和。我考虑过进行第二次前向传递并独立地向后调用每个损失项,但这会非常昂贵。
torch.autograd.grad您只能通过在网络上多次反向传播来获得梯度的不同项。为了避免对输入执行多次推理,您可以使用torch.autograd.grad效用函数而不是执行传统的向后传递backward。这意味着您不会污染来自不同项的梯度。
这是一个显示基本思想的最小示例:
>>> x = torch.rand(1, 10, requires_grad=True)
>>> lossA = x.pow(2).sum()
>>> lossB = x.mean()
Run Code Online (Sandbox Code Playgroud)
然后对每个不合适的项执行一次向后传递。您必须保留除最后一次之外的所有调用的图表:
>>> gradA = torch.autograd.grad(lossA, x, retain_graph=True)
(tensor([[1.5810, 0.6684, 0.1467, 0.6618, 0.5067, 0.2368, 0.0971, 0.4533, 0.3511,
1.9858]]),)
>>> gradB = torch.autograd.grad(lossB, x)
(tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
0.1000]]),)
Run Code Online (Sandbox Code Playgroud)
此方法有一些限制,因为您以元组形式接收参数的梯度,这不太方便。
backward另一种解决方案是在每次连续调用后缓存梯度backward:
>>> lossA = x.pow(2).sum()
>>> lossB = x.mean()
>>> lossA.backward(retain_graph=True)
Run Code Online (Sandbox Code Playgroud)
存储梯度并清除.grad属性(不要忘记这样做,否则 的梯度lossA会污染gradB。在处理多个张量参数时,您必须适应一般情况:
>>> x.gradA = x.grad
>>> x.grad = None
Run Code Online (Sandbox Code Playgroud)
向后传递下一个损失项:
>>> lossB.backward()
>>> x.gradB = x.grad
Run Code Online (Sandbox Code Playgroud)
然后您可以在本地与每个梯度项进行交互(即分别针对每个参数):
>>> x.gradA, x.gradB
(tensor([[1.5810, 0.6684, 0.1467, 0.6618, 0.5067, 0.2368, 0.0971, 0.4533, 0.3511,
1.9858]]),
tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
0.1000]]))
Run Code Online (Sandbox Code Playgroud)
后一种方法似乎更实用。
这本质上取决于torch.autograd.grad与torch.autograd.backward,即 异地与就地......并且最终取决于您的需求。您可以在此处阅读有关这两个函数的更多信息。
| 归档时间: |
|
| 查看次数: |
2650 次 |
| 最近记录: |