Train the last n% of BERT's layers in PyTorch using the HuggingFace library (e.g. train the last 5 BertLayers out of 12)

Des*_*wal · 1 · tags: nlp, deep-learning, torch, pytorch, huggingface-transformers

BERT has an architecture like encoder -> 12 BertLayer -> pooling. I want to train only the last 40% of the model's layers. I can freeze all the parameters with:

from transformers import AutoModel

# freeze all parameters
bert = AutoModel.from_pretrained('bert-base-uncased')
for param in bert.parameters():
    param.requires_grad = False


But I want to train only the last 40% of the layers. When I run len(list(bert.parameters())) it gives me 199, so let's assume 79 parameter tensors is roughly 40%. I can do something like this:

for param in list(bert.parameters())[-79:]:  # 199 trainable param tensors in total; the last 79 is ~40%
    param.requires_grad = False

I think this will freeze the first 60% of the layers.

Also, can someone tell me which layers this will freeze, in terms of the architecture?
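As a side note, the 199 count can be decomposed by hand. This is a sketch of the arithmetic, assuming bert-base-uncased's standard layout (5 embedding tensors, 16 tensors per BertLayer, 2 pooler tensors):

```python
# Parameter-tensor arithmetic for bert-base-uncased (assumed standard layout).
embedding_tensors = 5    # word, position, token_type embeddings + LayerNorm weight/bias
per_layer_tensors = 16   # q/k/v, attention output, FFN, 2 LayerNorms - each with weight and bias
num_layers = 12
pooler_tensors = 2       # pooler dense weight and bias

total = embedding_tensors + num_layers * per_layer_tensors + pooler_tensors
print(total)  # 199, matching len(list(bert.parameters()))

# How far back does the [-79:] slice reach?
layers_covered = (79 - pooler_tensors) / per_layer_tensors
print(layers_covered)  # 4.8125
```

Under that layout, the last 79 tensors are the pooler, the last 4 full BertLayers (8-11), and 13 of the 16 tensors of layer 7 - so the slice does not fall on a clean layer boundary.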

cro*_*oik 5

You are probably looking for named_parameters:

for name, param in bert.named_parameters():                                            
    print(name)

Output:

embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.query.bias
encoder.layer.0.attention.self.key.weight
...

named_parameters will also show you that you have not frozen the first 60% but the last 40%. Print only the parameters that are still trainable:

for name, param in bert.named_parameters():
    if param.requires_grad == True:
        print(name)

Output:

embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.query.bias
encoder.layer.0.attention.self.key.weight
encoder.layer.0.attention.self.key.bias
encoder.layer.0.attention.self.value.weight
...

You can freeze the first 60% (and thereby train only the last 40%) with:

for name, param in list(bert.named_parameters())[:-79]:  # everything except the last 79 tensors
    param.requires_grad = False
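If the goal is really "train the last 5 BertLayers out of 12", selecting by parameter name is less brittle than slicing by position, because a position slice can cut through the middle of a layer. A minimal sketch (the prefix set and should_train helper are my own, assuming bert-base-uncased's encoder.layer.0 ... encoder.layer.11 naming):

```python
# Decide trainability from the parameter name rather than list position,
# so exactly the last 5 of 12 BertLayers (plus the pooler) are trained.
# Prefixes assume bert-base-uncased's naming scheme (hypothetical helper).
TRAINABLE_PREFIXES = tuple(f'encoder.layer.{i}.' for i in range(7, 12)) + ('pooler.',)

def should_train(name: str) -> bool:
    """True if this named parameter belongs to a block we want to fine-tune."""
    return name.startswith(TRAINABLE_PREFIXES)

# Applying it (requires downloading the checkpoint):
# from transformers import AutoModel
# bert = AutoModel.from_pretrained('bert-base-uncased')
# for name, param in bert.named_parameters():
#     param.requires_grad = should_train(name)
```

Note the trailing dot in each prefix: it keeps 'encoder.layer.1.' from accidentally matching 'encoder.layer.11.' and friends.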