Batch Normalization: Derivative Computation and Code Implementation

For a walkthrough of the batch normalization paper, see the earlier post Batch Normalization. Here we mainly derive the derivatives of BN and then take a look at the source implementation in TensorFlow.

Derivation of the BN Derivatives

Forward Computation

Recall the forward-pass formulas for Batch Normalization:

\mu=\frac{1}{m}\sum_{i=1}^{m}{x_i}                        (1)

\sigma^2=\frac{1}{m}\sum_{i=1}^{m}{(x_i - \mu)^2}          (2)

\hat{x_i} = \frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}}                      (3)

y_i = \gamma \hat{x_i}+\beta                        (4)

My own implementation of this functionality is as follows:

import numpy as np


def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7
    implementation of batch normalization also uses running averages.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var: Array of shape (D,) giving running variance of features

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == 'train':
        #######################################################################
        # TODO: Implement the training-time forward pass for batch norm.      #
        # Use minibatch statistics to compute the mean and variance, use      #
        # these statistics to normalize the incoming data, and scale and      #
        # shift the normalized data using gamma and beta.                     #
        #                                                                     #
        # You should store the output in the variable out. Any intermediates  #
        # that you need for the backward pass should be stored in the cache   #
        # variable.                                                           #
        #                                                                     #
        # You should also use your computed sample mean and variance together #
        # with the momentum variable to update the running mean and running   #
        # variance, storing your result in the running_mean and running_var   #
        # variables.                                                          #
        #                                                                     #
        # Note that though you should be keeping track of the running         #
        # variance, you should normalize the data based on the standard       #
        # deviation (square root of variance) instead!                        # 
        # Referencing the original paper (https://arxiv.org/abs/1502.03167)   #
        # might prove to be helpful.                                          #
        #######################################################################
        # Normalize with the statistics of the current minibatch (eq. 1-4).
        mean = np.sum(x, axis=0) / N
        var = np.sum((x - mean) ** 2, axis=0) / N
        x_hat = (x - mean) / np.sqrt(var + eps)
        out = gamma * x_hat + beta
        # Cache intermediates needed by the backward pass.
        cache = {}
        cache['x_hat'] = x_hat
        cache['gamma'] = gamma
        cache['var_eps'] = var + eps
        cache['x_norm'] = x - mean  # centered input x - mu
        # Exponentially decaying running averages, used at test time.
        running_mean = momentum * running_mean + (1 - momentum) * mean
        running_var = momentum * running_var + (1 - momentum) * var
        #######################################################################
        #                           END OF YOUR CODE                          #
        #######################################################################
    elif mode == 'test':
        #######################################################################
        # TODO: Implement the test-time forward pass for batch normalization. #
        # Use the running mean and variance to normalize the incoming data,   #
        # then scale and shift the normalized data using gamma and beta.      #
        # Store the result in the out variable.                               #
        #######################################################################
        # At test time, normalize with the stored running statistics.
        x_hat = (x - running_mean) / np.sqrt(running_var + eps)
        out = gamma * x_hat + beta
        #######################################################################
        #                          END OF YOUR CODE                           #
        #######################################################################
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
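
To make the interface concrete, here is a minimal usage sketch (the toy shapes and data below are hypothetical, and numpy is assumed to be imported as np): in train mode the batch is normalized with its own statistics, while bn_param accumulates the running averages that test mode uses later.

np.random.seed(0)
N, D = 64, 5                          # hypothetical batch size / feature count
x = 3 * np.random.randn(N, D) + 10    # toy data with non-zero mean and non-unit variance
gamma, beta = np.ones(D), np.zeros(D)
bn_param = {'mode': 'train'}

out, cache = batchnorm_forward(x, gamma, beta, bn_param)
print(out.mean(axis=0))   # roughly 0 for every feature
print(out.std(axis=0))    # roughly 1 for every feature

# After enough training steps the running averages stabilize; at test time
# they, rather than the current batch statistics, are used for normalization.
bn_param['mode'] = 'test'
out_test, _ = batchnorm_forward(x, gamma, beta, bn_param)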

A few points in the code deserve attention:

  1. We must distinguish between the training and test phases: during training, mean and var are computed from the current minibatch, but at test time we normalize with running_mean and running_var.
  2. running_mean and running_var are exponentially decaying averages of the per-batch mean and var, accumulated with the momentum parameter, and are only used at test time.
  3. Quite a few intermediate values are cached so they can be reused when computing the gradients.
  4. m denotes the number of samples and c the number of features, so following the CS231n convention x has shape (m, c). \gamma is multiplied element-wise with every x_i, so it has shape (1, c); \beta likewise has shape (1, c); and y has shape (m, c). Since batch normalization only transforms the distribution of the activations, the output has the same shape as the input (a small shape sketch follows this list).
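
As a quick illustration of the shapes in point 4, the following sketch (a hypothetical toy example, with numpy imported as np) shows how numpy broadcasts the (c,) arrays gamma and beta against the (m, c) input exactly like (1, c) row vectors:

m, c = 4, 3                               # hypothetical sample / feature counts
x = np.random.randn(m, c)                 # input:                (m, c)
gamma = np.random.randn(c)                # scale, broadcasts as  (1, c)
beta = np.random.randn(c)                 # shift, broadcasts as  (1, c)

mu = x.mean(axis=0)                       # per-feature mean:     (c,)
var = x.var(axis=0)                       # per-feature variance: (c,)
x_hat = (x - mu) / np.sqrt(var + 1e-5)    # normalized:           (m, c)
y = gamma * x_hat + beta                  # output:               (m, c), same shape as x
print(x.shape, gamma.shape, x_hat.shape, y.shape)   # (4, 3) (3,) (4, 3) (4, 3)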

Backward Computation

Because the computations in the forward pass are all element-wise rather than dot products, we essentially do not need to worry about matrix transposes; what we do need to be careful about is when to sum.

In the backward pass we are given the values saved in cache and the derivative of the loss with respect to y, dout (the matrix of derivatives with respect to y_1, y_2, \ldots, y_m), and we need three gradients: d\gamma, d\beta, and dx (strictly speaking these are partial derivatives of the loss with respect to each quantity, but for brevity we simply write d\gamma, and so on).

First, compute d\gamma.

Because the final loss is a sum of terms, one per y_i, differentiating with respect to \gamma requires a sum:

cost = \sum_{i=1}^{m}{L(y_i)}

d\gamma =\sum_{i=1}^{m}{dy_i * \frac{\partial y_i}{\partial \gamma}} = \sum_{i=1}^{m}{dy_i*\hat{x_i}}=np.sum(dout * x_hat, axis=0)                   (5)

Next, compute d\beta:

d\beta =\sum_{i=1}^{m}{dy_i * \frac{\partial y_i}{\partial \beta}} = \sum_{i=1}^{m}{dy_i}=np.sum(dout, axis=0)                          (6)

Finally, compute dx.

Because x_i appears in \mu, \sigma^2, and \hat{x_i}, the derivative can be written as:

dx_i=d\mu*\frac{\partial \mu}{\partial x_i} + d\sigma^2 *\frac{\partial \sigma^2}{\partial x_i} + d\hat{x_i}*\frac{\partial \hat{x_i}}{\partial x_i}=d\mu *\frac{1}{m}+d\sigma^2*\frac{2(x_i - \mu)}{m} + d\hat{x_i}*\frac{1}{\sqrt{\sigma^2+\epsilon}}                  (7)

So the remaining task is to find d\mu, d\sigma^2, and d\hat{x_i}.

d\hat{x_i}=dy_i*\gamma               (8)

Because \sigma^2 enters the loss only through the \hat{x_i}, and it is the sum over all the \hat{x_i} terms that forms the final cost, we can derive:

d\sigma^2=\sum_{i=1}^{m}{d\hat{x_i}*\frac{\partial \hat{x_i}}{\partial \sigma^2}}=\sum_{i=1}^{m}{d\hat{x_i}*(\frac{-1}{2}*(x_i-\mu) * (\sigma^2+\epsilon)^{\frac{-3}{2}})}                   (9)

\mu appears in both \hat{x_i} and \sigma^2, so:

d\mu=\sum_{i=1}^{m}{d\hat{x_i}*\frac{\partial \hat{x_i}}{\partial \mu}} + d\sigma^2*\frac{\partial \sigma^2}{\partial \mu}=\sum_{i=1}^{m}{d\hat{x_i}*\frac{-1}{\sqrt{\sigma^2+\epsilon}}}+ d\sigma^2*\frac{-2}{m}\sum_{i=1}^{m}{(x_i-\mu)}                                   (10)

Substituting equations (8), (9), and (10) into one another and then into (7) yields dx.
 

My implementation of the backward pass is as follows:

def batchnorm_backward(dout, cache):
    """
    Backward pass for batch normalization.

    For this implementation, you should write out a computation graph for
    batch normalization on paper and propagate gradients backward through
    intermediate nodes.

    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from batchnorm_forward.

    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    dx, dgamma, dbeta = None, None, None
    ###########################################################################
    # TODO: Implement the backward pass for batch normalization. Store the    #
    # results in the dx, dgamma, and dbeta variables.                         #
    # Referencing the original paper (https://arxiv.org/abs/1502.03167)       #
    # might prove to be helpful.                                              #
    ###########################################################################
    # Unpack the values cached in the forward pass.
    x_hat = cache['x_hat']
    gamma = cache['gamma']
    var_eps = cache['var_eps']      # sigma^2 + eps
    x_norm = cache['x_norm']        # centered input x - mu
    N = dout.shape[0]
    dgamma = np.sum(dout * x_hat, axis=0)          # eq. (5)
    dbeta = np.sum(dout, axis=0)                   # eq. (6)

    dx_hat = dout * gamma                          # eq. (8)
    dvar = np.sum(dx_hat * (x_norm * (-1 / 2) * (var_eps ** (-3 / 2))), axis=0)   # eq. (9)
    dmean = np.sum(dx_hat * (-1 / np.sqrt(var_eps)), axis=0) + dvar * (-2) * np.sum(x_norm, axis=0) / N   # eq. (10)

    dx = dx_hat / np.sqrt(var_eps) + dmean / N + dvar * 2 * x_norm / N            # eq. (7)
    
    ###########################################################################
    #                             END OF YOUR CODE                            #
    ###########################################################################

    return dx, dgamma, dbeta
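
A standard way to check dgamma, dbeta, and dx is to compare them with a centered-difference numerical gradient. The numeric_grad helper below is a hypothetical, self-contained stand-in for the gradient-check utilities shipped with the CS231n assignment, and the toy sizes are arbitrary; the printed differences should be very small (around 1e-8 or below) if the backward pass is correct.

def numeric_grad(f, x, df, h=1e-5):
    # Centered-difference gradient of sum(f(x) * df) with respect to x.
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        old = x[idx]
        x[idx] = old + h
        f_plus = f(x).copy()
        x[idx] = old - h
        f_minus = f(x).copy()
        x[idx] = old
        grad[idx] = np.sum((f_plus - f_minus) * df) / (2 * h)
        it.iternext()
    return grad

np.random.seed(1)
N, D = 4, 3                                   # hypothetical toy sizes
x = np.random.randn(N, D)
gamma, beta = np.random.randn(D), np.random.randn(D)
dout = np.random.randn(N, D)

fx = lambda x: batchnorm_forward(x, gamma, beta, {'mode': 'train'})[0]
fg = lambda g: batchnorm_forward(x, g, beta, {'mode': 'train'})[0]

_, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
dx, dgamma, dbeta = batchnorm_backward(dout, cache)
print(np.max(np.abs(dx - numeric_grad(fx, x, dout))))         # should be tiny
print(np.max(np.abs(dgamma - numeric_grad(fg, gamma, dout)))) # should be tiny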

The formulas provided in the original paper (https://arxiv.org/abs/1502.03167) are the same chain-rule gradients:

[figure: the paper's gradient equations, not reproduced here]

Simplifying the Formulas

However, the formula for dx_i can in fact be simplified further, because

\hat{x_i} = \frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}}, so d\sigma^2 simplifies to:

d\sigma^2=\sum_{i=1}^{m}{d\hat{x_i}*(\frac{-1}{2}*\frac{\hat{x_i}}{\sigma^2+\epsilon})}=\frac{-1}{2*(\sigma^2+\epsilon)}*\sum_{i=1}^{m}{d\hat{x_i}*\hat{x_i}}

Moreover, \sum_{i=1}^{m}{(x_i-\mu)} is exactly 0 (by the definition of the mean \mu), so d\mu simplifies to:

d\mu = \frac{-1}{\sqrt{\sigma^2+\epsilon}}*\sum_{i=1}^{m}{d\hat{x_i}}

Substituting the simplified d\mu and d\sigma^2 into dx_i:

dx_i=\frac{-1}{\sqrt{\sigma^2+\epsilon}}*\sum_{i=1}^{m}{d\hat{x_i}} *\frac{1}{m}+\frac{2(x_i - \mu)}{m}*\frac{-1}{2*(\sigma^2+\epsilon)}*\sum_{i=1}^{m}{d\hat{x_i}*\hat{x_i}} + d\hat{x_i}*\frac{1}{\sqrt{\sigma^2+\epsilon}}

=\frac{1}{\sqrt{\sigma^2+\epsilon}} * (-\frac{1}{m}\sum_{i=1}^{m}{d\hat{x_i}} - \frac{(x_i-\mu)}{m*\sqrt{\sigma^2 + \epsilon}}*\sum_{i=1}^{m}{d\hat{x_i}*\hat{x_i}} + d\hat{x_i})

=\frac{1}{m*\sqrt{\sigma^2+\epsilon}} * (m*d\hat{x_i} - \sum_{i=1}^{m}{d\hat{x_i}} - \hat{x_i}*\sum_{i=1}^{m}{d\hat{x_i}*\hat{x_i}})

The corresponding code:

def batchnorm_backward_alt(dout, cache):
    """
    Alternative backward pass for batch normalization.

    For this implementation you should work out the derivatives for the batch
    normalizaton backward pass on paper and simplify as much as possible. You
    should be able to derive a simple expression for the backward pass. 
    See the jupyter notebook for more hints.
     
    Note: This implementation should expect to receive the same cache variable
    as batchnorm_backward, but might not use all of the values in the cache.

    Inputs / outputs: Same as batchnorm_backward
    """
    dx, dgamma, dbeta = None, None, None
    ###########################################################################
    # TODO: Implement the backward pass for batch normalization. Store the    #
    # results in the dx, dgamma, and dbeta variables.                         #
    #                                                                         #
    # After computing the gradient with respect to the centered inputs, you   #
    # should be able to compute gradients with respect to the inputs in a     #
    # single statement; our implementation fits on a single 80-character line.#
    ###########################################################################
    # Unpack the values cached in the forward pass.
    x_hat = cache['x_hat']
    gamma = cache['gamma']
    var_eps = cache['var_eps']      # sigma^2 + eps
    N = dout.shape[0]
    dgamma = np.sum(dout * x_hat, axis=0)
    dbeta = np.sum(dout, axis=0)
    
    dx_hat = dout * gamma

    # Simplified expression for dx derived above.
    dx = (N * dx_hat - np.sum(dx_hat, axis=0)
          - x_hat * np.sum(dx_hat * x_hat, axis=0)) / (N * np.sqrt(var_eps))
    ###########################################################################
    #                             END OF YOUR CODE                            #
    ###########################################################################

    return dx, dgamma, dbeta
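
Since both backward passes read the same cache, a quick sanity check (again a sketch with hypothetical toy sizes) is to confirm that they return identical gradients and to compare their running time; the simplified version is often a bit faster because it skips the separate dvar and dmean sums.

import time

np.random.seed(2)
N, D = 128, 64                            # hypothetical batch / feature sizes
x = np.random.randn(N, D)
gamma, beta = np.random.randn(D), np.random.randn(D)
dout = np.random.randn(N, D)

_, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})

t0 = time.time()
dx1, dgamma1, dbeta1 = batchnorm_backward(dout, cache)
t1 = time.time()
dx2, dgamma2, dbeta2 = batchnorm_backward_alt(dout, cache)
t2 = time.time()

print('dx diff:    ', np.max(np.abs(dx1 - dx2)))        # should be ~0 (floating-point noise)
print('dgamma diff:', np.max(np.abs(dgamma1 - dgamma2)))
print('dbeta diff: ', np.max(np.abs(dbeta1 - dbeta2)))
print('speedup:    ', (t1 - t0) / (t2 - t1 + 1e-12))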

 
