finetuning Bert時的權重衰減

權重衰減

L2正則化的目的就是爲了讓權重衰減到更小的值,在一定程度上減少模型過擬合的問題,所以權重衰減也叫L2正則化。
在這裏插入圖片描述

Bert中的權重衰減

並不是所有的權重參數都需要衰減,比如bias,和LayerNorm.weight就不需要衰減。

from transformers import BertConfig, BertForSequenceClassification, AdamW
import torch
import torch.nn as nn

# 使用GPU
# 通過model.to(device)的方式使用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2, hidden_dropout_prob=0.2)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=config)
model.to(device)

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# 對應optimizer_grouped_parameters中的第一個dict,這裏面的參數需要權重衰減
need_decay = []
for n, p in model.named_parameters():
    if not any(nd in n for nd in no_decay):
        need_decay.append(n)
        
# 對應optimizer_grouped_parameters中的第二個dict,這裏面的參數不需要權重衰減
not_decay = []
for n, p in model.named_parameters():
    if any(nd in n for nd in no_decay):
        not_decay.append(n)
        
# AdamW是實現了權重衰減的優化器
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
criterion = nn.CrossEntropyLoss()

print("Done")
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章