AdamW, LAMB: 大型預訓練模型常用優化器

前言

按照時間上的迭代順序,近些年神經網絡先後出現了 Gradient Descent (GD)、Momentum、Adaptive Gradient (AdaGrad)、Root Mean Square prop (RMSprop)、Adaptive Moment estimation (Adam) 等優秀的優化器。到如今,大部分 NLP 預訓練模型已不再使用這些方法,而是使用 Adam Weight Decay Regularization (AdamW) 和去年首度亮相的 Layer-wise Adaptive Moments optimizer for Batching training (LAMB)。爲何最爲傳統的 GD,包括衍生的 stochastic GD、mini-batch GD 優化器已不再使用,下文會有詳細的介紹。

Gradient Descent (GD)

梯度下降法是最爲經典的凸優化優化器,思想也非常明確:通過 loss 反向傳導計算參數的梯度,參數往哪個方向跑可以讓 loss 下降,就讓參數往哪個方向更新:
ΔWk=lossWk=lossZnZnZn1...Zk+1Wk\Delta W_k=\frac{\partial loss}{\partial W_k}=\frac{\partial loss}{\partial Z_n}\frac{\partial Z_n}{\partial Z_{n-1}}...\frac{\partial Z_{k+1}}{\partial W_k}

WkWkαΔWkW_k\leftarrow W_k-\alpha \Delta W_k

需要注意的是,WkW_k 中的每一個浮點元素的梯度計算和梯度更新,相互之間是完全獨立的,這對於理解梯度更新的機理非常重要。上式中,α\alpha 爲學習率,通常是一個固定的超參數,學習率越高,收斂越快。但需要注意控制範圍。學習率過大,容易造成梯度跨過參數的局部最優點造成參數震盪;學習率過小,會導致訓練過程過於漫長。爲避免參數震盪,使用 GD 時,學習率通常設置在一個較低值,且訓練的 batch_size 越大,學習率越低。梯度裁剪雖能一定程度上解決梯度震盪的問題,但由於輸出的概率分佈發生偏移,模型收斂也受到一定負面影響,因此需儘可能避免對梯度裁剪的依賴。

Adaptive Moment estimation (Adam)

爲解決 GD 中固定學習率帶來的不同參數間收斂速度不一致的弊端,AdaGrad 和 RMSprop 誕生出來,爲每個參數賦予獨立的學習率。計算梯度後,梯度較大的參數獲得的學習率較低,反之亦然。此外,爲避免每次梯度更新時都獨立計算梯度,導致梯度方向持續變化,Momentum 將上一輪梯度值加入到當前梯度的計算中,通過某種權重對兩者加權求和,獲得當前批次參數更新的更新值。 Adam 結合了這兩項考慮,既爲每一個浮點參數自適應性地設置學習率,又將過去的梯度歷史納入考量:
mt=β1mt1+(1β1)ΔWm_t=\beta_1m_{t-1}+(1-\beta_1)\Delta W

vt=β2vt1+(1β2)ΔW2v_t=\beta_2v_{t-1}+(1-\beta_2)\Delta W^2

mt^=mt1β1t\hat{m_t}=\frac{m_t}{1-\beta_1^t}

vt^=vt1β2t\hat{v_t}=\frac{v_t}{1-\beta_2^t}

WtWt1αvt^+ϵmt^W_t\leftarrow W_{t-1}-\frac{\alpha}{\sqrt{\hat{v_t}}+\epsilon}\hat{m_t}

實際使用中,通常 β1=0.9\beta_1=0.9β2>0.9\beta_2>0.9。BERT 源代碼中,預訓練的 β2\beta_2 爲 0.98,微調的 β2\beta_2 爲 0.999,其目的是爲了減少對預訓練中得到的原始參數結構的破壞,使收斂更爲平緩。此外,m0m_0v0v_0 皆爲初始化得來,因此訓練時參數種子的設置往往對模型結果的影響較大。從上述公式可以看出,訓練前期的學習率和梯度更新是比較激進的,到後期逐漸平穩。

雖然 Adam 優化器的使用會導致內存中多出兩倍於原參數體量的佔用,但與之換來的訓練收益使得學術界並沒有放棄這一高效的方法。

Adam Weight Decay Regularization (AdamW)

Adam 雖然收斂速度快,但沒能解決參數過擬合的問題。學術界討論了諸多方案,其中包括在損失函數中引入參數的 L2 正則項。這樣的方法在其他的優化器中或許有效,但會因爲 Adam 中自適應學習率的存在而對使用 Adam 優化器的模型失效。AdamW 的出現便是爲了解決這一問題,達到同樣使參數接近於 0 的目的。具體的舉措,是在最終的參數更新時引入參數自身:
mt=β1mt1+(1β1)ΔWm_t=\beta_1m_{t-1}+(1-\beta_1)\Delta W

vt=β2vt1+(1β2)ΔW2v_t=\beta_2v_{t-1}+(1-\beta_2)\Delta W^2

mt^=mt1β1t\hat{m_t}=\frac{m_t}{1-\beta_1^t}

vt^=vt1β2t\hat{v_t}=\frac{v_t}{1-\beta_2^t}

WtWt1α(mt^vt^+ϵ+λWt1)W_t\leftarrow W_{t-1}-\alpha\big(\frac{\hat{m_t}}{\sqrt{\hat{v_t}}+\epsilon}+\lambda W_{t-1}\big)

λ\lambda 即爲權重衰減因子,常見的設置爲 0.005/0.01。這一優化策略目前正廣泛應用於各大預訓練語言模型。

Layer-wise Adaptive Moments optimizer for Batching training (LAMB)

LAMB 優化器是 2019 年出現的一匹新秀,原論文標題後半部分叫做 “Training BERT in 76 Minutes”,足以看出其野心之大。 LAMB 出現的目的是加速預訓練進程,這個優化器也成爲 NLP 社區爲泛機器學習領域做出的一大貢獻。在使用 Adam 和 AdamW 等優化器時,一大問題在於 batch size 存在一定的隱式上限,一旦突破這個上限,梯度更新極端的取值會導致自適應學習率調整後極爲困難的收斂,從而無法享受增加的 batch size 帶來的提速增益。LAMB 優化器的作用便在於使模型在進行大批量數據訓練時,能夠維持梯度更新的精度:
mt=β1mt1+(1β1)ΔWm_t=\beta_1m_{t-1}+(1-\beta_1)\Delta W

vt=β2vt1+(1β2)ΔW2v_t=\beta_2v_{t-1}+(1-\beta_2)\Delta W^2

rt=mtvt+ϵr_t=\frac{m_t}{\sqrt{v_t}+\epsilon}

WtWt1αϕ(Wt1rt+λWt1)(rt+λWt1)W_t\leftarrow W_{t-1}-\alpha\cdot\phi\big(\frac{||W_{t-1}||}{||r_t+\lambda W_{t-1}||}\big)(r_t+\lambda W_{t-1})

其中,ϕ\phi 是一個可選擇的映射函數,一種是 ϕ(z)=z\phi(z)=z,另一種則爲起到歸一化作用的 ϕ(z)=min(max(z,γl),γu)\phi(z)=\min(\max(z, \gamma_l),\gamma_u)γl\gamma_lγu\gamma_u 爲預先設定的超參數,分別代表參數調整的下界和上界。這一簡單的調整所帶來的實際效果非常顯著。使用 AdamW 時,batch size 超過 512 便會導致模型效果大幅下降,但在 LAMB 下,batch size 可以直接提到 32,000 而不會導致精度損失。

由於在下游微調預訓練模型時,通常無需過大的數據集,因而 LAMB 僅在預訓練環節使用。遺憾的是,LAMB 在 batch size 512 以下時無法起到顯著作用,目前只能作爲大體量財團的工具。

附錄

以下是 LAMB 優化器的 tensorflow1.x 代碼,可作爲參考以理解算法,具體的代碼出處已無法找尋。

class LAMBOptimizer(tf.train.Optimizer):
    '''
    LAMBOptimizer optimizer.
	
	# Important Note
		- This is NOT an official implementation.
		- LAMB optimizer is changed from arXiv v1 ~ v3.
		- We implement v3 version (which is the latest version on June, 2019.).
		- Our implementation is based on `AdamWeightDecayOptimizer` in BERT (provided by Google).
    # References
		- LAMB optimier: https://github.com/ymcui/LAMB_Optimizer_TF
		- Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. https://arxiv.org/abs/1904.00962v3
		- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://arxiv.org/abs/1810.04805
    # Parameters
		- There is nothing special, just the same as `AdamWeightDecayOptimizer`.
    '''
    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.01,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-6,
                 exclude_from_weight_decay=None,
                 name="LAMBOptimizer"):
        """Constructs a LAMBOptimizer."""
        super(LAMBOptimizer, self).__init__(False, name)

        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """See base class."""
        assignments = []
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = self._get_variable_name(param.name)

            m = tf.get_variable(
                name=param_name + "/lamb_m",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
            v = tf.get_variable(
                name=param_name + "/lamb_v",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())

            # Standard Adam update.
            next_m = (
                    tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
            next_v = (
                    tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                              tf.square(grad)))

            update = next_m / (tf.sqrt(next_v) + self.epsilon)

            # Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want ot decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            if self._do_use_weight_decay(param_name):
                update += self.weight_decay_rate * param

            ############## BELOW ARE THE SPECIFIC PARTS FOR LAMB ##############

            # Note: Here are two choices for scaling function \phi(z)
            # minmax:   \phi(z) = min(max(z, \gamma_l), \gamma_u)
            # identity: \phi(z) = z
            # The authors does not mention what is \gamma_l and \gamma_u
            # UPDATE: after asking authors, they provide me the code below.
            # ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
            #      math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

            r1 = tf.sqrt(tf.reduce_sum(tf.square(param)))
            r2 = tf.sqrt(tf.reduce_sum(tf.square(update)))

            r = tf.where(tf.greater(r1, 0.0),
                         tf.where(tf.greater(r2, 0.0),
                                  r1 / r2,
                                  1.0),
                         1.0)

            eta = self.learning_rate * r

            update_with_lr = eta * update

            next_param = param - update_with_lr

            assignments.extend(
                [param.assign(next_param),
                 m.assign(next_m),
                 v.assign(next_v)])
        return tf.group(*assignments, name=name)

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _get_variable_name(self, param_name):
        """Get the variable name from the tensor name."""
        m = re.match("^(.*):\\d+$", param_name)
        if m is not None:
            param_name = m.group(1)
        return param_name
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章