Pytorch BERT 输入梯度

Pra*_*kar 0 gradient deep-learning pytorch bert-language-model huggingface-transformers

我正在尝试从 pytorch 中的 BERT 模型获取输入梯度。我怎样才能做到这一点?假设 y' = BertModel(x)。我试图找到 $d(loss(y,y'))/dx$

Ene*_*şık 5

Bert 模型的问题之一是,您的输入主要包含令牌 ID,而不是令牌嵌入,这使得获取梯度变得困难,因为令牌 ID 和令牌嵌入之间的关系已中断。要解决此问题,您可以使用令牌嵌入。

# get your batch data: token_id, mask and labels
token_ids, mask, labels = batch
  
# get your token embeddings
token_embeds=BertModel.bert.get_input_embeddings().weight[token_ids].clone()
# track gradient of token embeddings
token_embeds.requires_grad=True
    
# get model output that contains loss value
outs = BertModel(inputs_embeds=inputs_embeds,labels=labels)
loss=outs.loss
Run Code Online (Sandbox Code Playgroud)

获得损失值后,您可以在这个答案或向后函数中使用 torch.autograd.grad

loss.backward()
grad=token_embeds.grad
Run Code Online (Sandbox Code Playgroud)