PyTorch torch.no_grad() 与 requires_grad=False

sta*_*010 6 python machine-learning pytorch bert-language-model huggingface-transformers

我正在学习使用 Huggingface Transformers 库中的 BERT NLP 模型(特征提取器)的PyTorch 教程。有两段我不明白的梯度更新相关代码。

(1) torch.no_grad()

本教程有一个类,其中该forward()函数torch.no_grad()围绕对 BERT 特征提取器的调用创建一个块,如下所示:

bert = BertModel.from_pretrained('bert-base-uncased')

class BERTGRUSentiment(nn.Module):
    
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        
    def forward(self, text):
        with torch.no_grad():
            embedded = self.bert(text)[0]
Run Code Online (Sandbox Code Playgroud)

(2) param.requires_grad = False

在同一教程中还有另一部分冻结了 BERT 参数。

for name, param in model.named_parameters():                
    if name.startswith('bert'):
        param.requires_grad = False
Run Code Online (Sandbox Code Playgroud)

我什么时候需要(1)和/或(2)?

  • 如果我想使用冻结的 BERT 进行训练,是否需要同时启用两者?
  • 如果我想训练让 BERT 更新,我是否需要同时禁用两者?

另外,我运行了所有四种组合,发现:

   with torch.no_grad   requires_grad = False  Parameters  Ran
   ------------------   ---------------------  ----------  ---
a. Yes                  Yes                      3M        Successfully
b. Yes                  No                     112M        Successfully
c. No                   Yes                      3M        Successfully
d. No                   No                     112M        CUDA out of memory
Run Code Online (Sandbox Code Playgroud)

有人可以解释一下发生了什么吗?为什么我得到CUDA out of memory(d) 而不是 (b)?两者都有 112M 的可学习参数。

den*_*ger 13

这是一个较旧的讨论,多年来略有变化(主要是由于作为一种模式的目的。已经可以在 Stackoverflow 上with torch.no_grad()找到也可以回答您的问题的优秀答案。 但是,由于原始问题很大不同的是,我将避免标记为重复,特别是由于第二部分有关内存。

这里no_grad给出了初步解释:

with torch.no_grad()是一个上下文管理器,用于防止计算梯度[...]。

requires_grad另一方面被使用

冻结模型的一部分并训练其余部分[...]。

再次来源SO 帖子

本质上,requires_grad您只是禁用网络的一部分,而根本no_grad不会存储任何梯度,因为您可能将其用于推理而不是训练。
为了分析参数组合的行为,让我们调查一下发生了什么:

  • a)并且b)根本不存储任何梯度,这意味着无论参数数量有多少,您都有更多的可用内存,因为您不会保留它们用于潜在的向后传递。
  • c)必须存储前向传递以供以后的反向传播使用,但是,仅存储有限数量的参数(300 万),这使得这仍然是可以管理的。
  • d)然而,需要存储所有 1.12 亿个参数的前向传递,这会导致内存不足。