[AI Notes] Section 13: Writing a Custom Optimizer in TensorFlow 2.0, Implementing the AdaX Optimizer

Key methods

  • _create_slots: creates the slot variables (the running statistics) associated with each trainable variable, used later in the update computation.
  • _resource_apply_dense and _resource_apply_sparse: called for every dense/sparse gradient update of a variable; each returns the op that applies the update.
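
The update rule that `_resource_apply` in the code below implements can be summarized as follows (a reconstruction from the code itself, with g_t the gradient, alpha the learning rate, and t the step count):

$$m_t = \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t$$
$$v_t = (1 + \beta_2)\, v_{t-1} + \beta_2\, g_t^2$$
$$\hat{v}_t = \frac{v_t}{(1 + \beta_2)^t - 1}$$
$$\theta_t = \theta_{t-1} - \frac{\alpha\, m_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Unlike Adam, the second moment accumulates with a factor of (1 + beta_2), and the corresponding bias-correction denominator is (1 + beta_2)^t - 1.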

The AdaX optimizer implementation is as follows:

import tensorflow as tf


class AdaX(tf.keras.optimizers.Optimizer):
    r"""Implements AdaX algorithm.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 1e-4))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-12)
        weight_decay (float, optional): L2 penalty (default: 5e-4)
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.0001,
        epsilon=1e-6,
        **kwargs
    ):
        kwargs['name'] = kwargs.get('name') or 'AdaX_V2'
        super(AdaX, self).__init__(**kwargs)
        self._set_hyper('learning_rate', learning_rate)
        self._set_hyper('beta_1', beta_1)
        self._set_hyper('beta_2', beta_2)
        self.epsilon = epsilon
        print('self._initial_decay:', self._initial_decay)

    def _create_slots(self, var_list):
        '''
        Create the slot variables associated with each trainable variable;
        they hold the running statistics used in the gradient update.
        var_list: list of variables to be updated
        '''
        tf.print('var_list:', type(var_list))
        for var in var_list:
            self.add_slot(var, 'm')
            self.add_slot(var, 'v')

    def _resource_apply(self, grad, var, indices=None):
        '''Update rule shared by the dense and sparse cases.'''
        # Prepare variables
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        epsilon_t = tf.cast(self.epsilon, var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)

        # Update formulas
        if indices is None:
            m_t = m.assign(beta_1_t * m + (1 - beta_1_t) * grad)
            v_t = v.assign((1 + beta_2_t) * v + beta_2_t * grad**2)
        else:
            mv_ops = [
                m.assign(beta_1_t * m),
                v.assign((1 + beta_2_t) * v)
            ]
            with tf.control_dependencies(mv_ops):
                m_t = self._resource_scatter_add(
                    m, indices, (1 - beta_1_t) * grad
                )
                v_t = self._resource_scatter_add(
                    v, indices, beta_2_t * grad**2)

        # Return the update op
        # tf.control_dependencies runs the preceding ops before the body below executes
        with tf.control_dependencies([m_t, v_t]):
            v_t = v_t / (tf.pow(1.0 + beta_2_t, local_step) - 1.0)
            var_t = var.assign(var - lr_t * m_t / (tf.sqrt(v_t) + epsilon_t))
            return var_t

    def _resource_apply_dense(self, grad, var):
        '''Called for every dense gradient update; returns the variable update op.'''
        return self._resource_apply(grad, var)

    def _resource_apply_sparse(self, grad, var, indices):
        '''Called for every sparse gradient update; returns the variable update op.'''
        return self._resource_apply(grad, var, indices)

    def get_config(self):
        tf.print('get_config')
        config = {
            'learning_rate': self._serialize_hyperparameter('learning_rate'),
            'beta_1': self._serialize_hyperparameter('beta_1'),
            'beta_2': self._serialize_hyperparameter('beta_2'),
            'epsilon': self.epsilon,
        }
        base_config = super(AdaX, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
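
The following is a minimal usage sketch (the model, data, and hyperparameter values are illustrative assumptions, not part of the original post); the custom class plugs into model.compile like any built-in optimizer:

import numpy as np
import tensorflow as tf

# Toy data and model, purely for illustration.
x = np.random.random((256, 20)).astype('float32')
y = np.random.randint(0, 2, size=(256, 1)).astype('float32')

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(20,)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

# AdaX is used exactly like tf.keras.optimizers.Adam.
model.compile(optimizer=AdaX(learning_rate=1e-3), loss='binary_crossentropy')
model.fit(x, y, batch_size=32, epochs=2)

Because get_config serializes the hyperparameters, a model saved with this optimizer can later be restored by passing custom_objects={'AdaX': AdaX} to tf.keras.models.load_model.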

Reference: https://github.com/bojone/adax

 
