clip_gradient 的作用就是讓權重的更新限制在一個合適的範圍:
- 首先設置一個梯度閾值:clip_gradient
- 在反向傳播中求出各參數的梯度,這裏我們不直接使用梯度進行參數更新,我們求這些梯度的l2範數||g||
- 然後比較||g||與clip_gradient的大小
- 如果前者大,求縮放因子clip_gradient/||g||, 由縮放因子可以看出梯度越大,則縮放因子越小,這樣便很好地控制了梯度的範圍
- 最後將梯度乘上縮放因子便得到最後所需的梯度
pytorch版代碼如下:
#使用梯度剪切,防止梯度爆炸
def clip_gradient(model, clip_norm):
"""Computes a gradient clipping coefficient based on gradient norm."""
#基於梯度範數計算梯度剪切係數。
totalnorm = 0
for p in model.parameters():
if p.requires_grad and p.grad is not None:
modulenorm = p.grad.norm() #計算該參數所有梯度的L2範數
totalnorm += modulenorm ** 2
totalnorm = torch.sqrt(totalnorm).item()
norm = (clip_norm / max(totalnorm, clip_norm)) #得到梯度剪切係數
for p in model.parameters():
if p.requires_grad and p.grad is not None:
p.grad.mul_(norm)
其中,二階範數(也稱L2範數)是最常見的範數,即歐幾里得距離,表達式如下:
即,對N個數據求p範數。代碼中norm()函數參數都使用默認值,那麼就是求所有數值的2範數。
參考鏈接:https://blog.csdn.net/u010814042/article/details/76154391