循環神經網絡中比較容易出現梯度衰減或梯度爆炸,爲了應對梯度爆炸,可以進行裁剪梯度。假設把所有模型參數梯度的元素拼接成一個向量g,並設裁剪的閾值是。裁剪後的梯度的範數不超過。
通過代碼進行演示:
def grad_clipping(params, theta) #paras是模型參數,theta是閾值
norm = 0
for param in params:
norm += (param ** 2).sum()
norm = norm.sqrt()
if norm > theta:
for param in params:
param *= theta/norm