如何在pytorch中处理多重损失?

hei*_*hei 9 python pytorch

在此输入图像描述

比如这个,我想用一些辅助损失来提升我的模特表现.
哪个类型代码可以在pytorch中实现它?

#one
loss1.backward()
loss2.backward()
loss3.backward()
optimizer.step()
#two
loss1.backward()
optimizer.step() 
loss2.backward()
optimizer.step() 
loss3.backward()
optimizer.step()   
#three
loss = loss1+loss2+loss3
loss.backward()
optimizer.step()
Run Code Online (Sandbox Code Playgroud)

感谢您的回答!

Shi*_*han 11

第一种和第三种尝试完全相同且正确,而第二种方法则完全错误。

原因是,在Pytorch中,低层渐变不会 被后续backward()调用“覆盖” ,而是被累积或求和。这使第一种方法和第三种方法相同,但是如果您使用低内存的GPU / RAM,则第一种方法可能更可取,因为具有立即backward() + step()调用的1024个批处理大小与具有128个backward()调用的8个批处理和具有8个调用的批处理是相同step()的结束。

为了说明这一点,下面是一个简单的示例。我们希望同时使张量x最接近[40,50,60]

x = torch.tensor([1.0],requires_grad=True)
loss1 = criterion(40,x)
loss2 = criterion(50,x)
loss3 = criterion(60,x)
Run Code Online (Sandbox Code Playgroud)

现在是第一种方法:(我们tensor.grad用来获取张量的当前梯度x

loss1.backward()
loss2.backward()
loss3.backward()

print(x.grad)
Run Code Online (Sandbox Code Playgroud)

输出:(tensor([-294.])编辑:retain_graph=True在前两个backward调用中添加了更复杂的计算图)

第三种方法:

loss = loss1+loss2+loss3
loss.backward()
print(x.grad)
Run Code Online (Sandbox Code Playgroud)

再次输出是: tensor([-294.])

第二种方法有所不同,因为我们不会opt.zero_grad在调用step()方法之后调用。这意味着在所有3个呼叫中都使用step了首次backward呼叫的梯度。例如,如果3个损失5,1,4为相同的重量提供梯度,而不是10(= 5 + 1 + 4),那么您的体重将具有5*3+1*2+4*1=21梯度。

我确实同意这个结论,但是如果没有内存问题,请使用第三种方法。进一步阅读:链接1链接2


Kon*_*kos 5

- 对第一种方法的评论已删除,请参阅其他答案 -

你的第二种方法需要你反向传播retain_graph=True,这会产生大量的计算成本.此外,它是错误的,因为您将使用第一个优化器步骤更新网络权重,然后您的下一个backward()调用将在更新之前计算渐变,这意味着该second step()调用将在您的更新中插入噪声.另一方面,如果您forward()通过更新的权重执行另一个反向传播调用,则最终会进行异步优化,因为第一个层将使用第一个层更新一次step(),然后再针对每个后续step()调用再次更新(本质上没有错误) ,但效率低下,可能不是你想要的第一个).

长话短说,走的路是最后一种方法.将每个损失减少为标量,将损失相加并反向传播所造成的损失.边注; 确保你的减少方案有意义(例如,如果你使用减少='总和'并​​且损失对应于多标签分类,请记住每个目标的类数不同,因此每个损失所贡献的相对权重也会不同)


pro*_*sti 5

第三次尝试是最好的。

\n

两种不同的损失函数

\n

如果你有两个不同的损失函数,分别完成它们的前向,最后你可以做到(loss1 + loss2).backward()。它\xe2\x80\x99s 效率更高一些,跳过了相当多的计算。

\n

额外提示:将损失求和

\n

在您的代码中您想要执行以下操作:

\n
loss_sum += loss.item()\n
Run Code Online (Sandbox Code Playgroud)\n

以确保您不会跟踪所有损失的历史记录。

\n

item()将破坏图表,从而允许其从循环的一次迭代中释放到下一次。你也可以用detach()同样的方法。

\n