mxnet - A complete guide to the reshape operation (understanding 0, -1, -2, -3, -4)

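In MXNet, an operation is usually available for both NDArray and Symbol, which correspond to the dynamic (imperative) graph and the static (symbolic) graph. For reshape, for example, you can call mx.nd.reshape or mx.sym.reshape. Below we analyze the reshape operation, using mx.nd.reshape as the reference.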

The reshape docstring

reshape(data=None, shape=_Null, reverse=_Null, target_shape=_Null, keep_highest=_Null, out=None, name=None, **kwargs)
    Reshapes the input array.

    .. note:: ``Reshape`` is deprecated, use ``reshape``

    Given an array and a shape, this function returns a copy of the array in the new shape.
    The shape is a tuple of integers such as (2,3,4). The size of the new shape should be same as the size of the input array.

    Example::

      reshape([1,2,3,4], shape=(2,2)) = [[1,2], [3,4]]

    Some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}. The significance of each is explained below:

    - ``0``  copy this dimension from the input to the output shape.

      Example::

      - input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2)
      - input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4)

    - ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions
      keeping the size of the new array same as that of the input array.
      At most one dimension of shape can be -1.

      Example::

      - input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
      - input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
      - input shape = (2,3,4), shape=(-1,), output shape = (24,)

    - ``-2`` copy all/remainder of the input dimensions to the output shape.

      Example::

      - input shape = (2,3,4), shape = (-2,), output shape = (2,3,4)
      - input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4)
      - input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1)

    - ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension.

      Example::

      - input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
      - input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
      - input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
      - input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)

    - ``-4`` split one dimension of the input into two dimensions passed subsequent to -4 in shape (can contain -1).

      Example::

      - input shape = (2,3,4), shape = (-4,1,2,-2), output shape =(1,2,3,4)
      - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)

    If the argument `reverse` is set to 1, then the special values are inferred from right to left.

      Example::

      - without reverse=1, for input shape = (10,5,4), shape = (-1,0), output shape would be (40,5)
      - with reverse=1, output shape will be (50,4).

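reshape takes a shape tuple as its argument. Each entry can be a positive integer or one of the special values 0, -1, -2, -3, -4. The meaning of each special value is explained below.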

0

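0 acts as a placeholder: it keeps the input array's dimension at the corresponding position. Positions are matched from left to right by default (or from right to left when reverse=1 is passed).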

  • input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2) # the middle dimension is kept unchanged
  • input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4) # the last two dimensions are kept unchanged
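The left-to-right matching for 0 can be sketched in plain Python. This is not MXNet code: the function name `apply_zero` is hypothetical, and the sketch only accepts positive integers and 0, assuming (as in the examples above) that each entry of `shape` is matched with exactly one input position.

```python
import math


def apply_zero(in_shape, spec):
    """Resolve a reshape spec that contains only positive ints and 0.

    Entry i of `spec` is matched with input position i (left to right);
    a 0 copies the input dimension found at that position.
    """
    out = []
    for src_idx, d in enumerate(spec):
        if d == 0:
            if src_idx >= len(in_shape):
                raise ValueError("0 has no matching input dimension")
            out.append(in_shape[src_idx])
        elif d > 0:
            out.append(d)
        else:
            raise ValueError("this sketch only handles positive ints and 0")
    # reshape never changes the total number of elements
    if math.prod(out) != math.prod(in_shape):
        raise ValueError("new shape must have the same size as the input")
    return tuple(out)


print(apply_zero((2, 3, 4), (4, 0, 2)))  # (4, 3, 2)
print(apply_zero((2, 3, 4), (2, 0, 0)))  # (2, 3, 4)
```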

-1

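-1 is resolved last: once every other entry has been handled, its size is inferred from the constraint that the array's total size stays the same before and after reshape. A shape may contain at most one free-standing -1; however, one of the two numbers following a -4 may also be -1 (it is resolved when that dimension is split), so a shape that uses -4 can contain more than one -1.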

  • input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
  • input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
  • input shape = (2,3,4), shape=(-1,), output shape = (24,)
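A sketch of the "resolve -1 last" rule, again in plain Python with a hypothetical function name `infer_minus_one`: -1 first occupies a placeholder of 1 and is matched with one input position, and only after the loop is it replaced by the input size divided by the product of the other output dimensions. Only positive integers, 0 and a single -1 are handled here.

```python
import math


def infer_minus_one(in_shape, spec):
    """Resolve positive ints, 0 and at most one -1."""
    out, inf_idx = [], None
    for src_idx, d in enumerate(spec):
        if d == -1:
            if inf_idx is not None:
                raise ValueError("at most one -1 is allowed")
            inf_idx = len(out)
            out.append(1)  # placeholder, replaced after the loop
        elif d == 0:
            out.append(in_shape[src_idx])
        elif d > 0:
            out.append(d)
        else:
            raise ValueError("this sketch only handles positive ints, 0 and -1")
    total = math.prod(in_shape)
    if inf_idx is not None:
        # the placeholder 1 does not change the product of the known dims
        out[inf_idx] = total // math.prod(out)
    if math.prod(out) != total:
        raise ValueError("new shape must have the same size as the input")
    return tuple(out)


print(infer_minus_one((2, 3, 4), (6, 1, -1)))  # (6, 1, 4)
print(infer_minus_one((2, 3, 4), (3, -1, 8)))  # (3, 1, 8)
print(infer_minus_one((2, 3, 4), (-1,)))       # (24,)
```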

-2

Unlike -1, -2 can stand for several dimensions: it copies every input dimension not yet consumed by the entries to its left into the output shape.

  • input shape = (2,3,4), shape = (-2,), output shape = (2,3,4) # -2 takes all of the dimensions
  • input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4) # 2 takes one dimension, -2 takes the remaining (3,4)
  • input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1) # (1,1) are two new dimensions, -2 takes (2,3,4)
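Because -2 can consume several input dimensions, the sketch below tracks an explicit input cursor `src_idx` (a name chosen for illustration; `apply_minus_two` is likewise hypothetical): ordinary entries and 0 advance it by one, while -2 copies everything from the cursor to the end. This is an illustrative sketch, not MXNet's implementation.

```python
import math


def apply_minus_two(in_shape, spec):
    """Resolve positive ints, 0 and -2 (copy all remaining input dims)."""
    out, src_idx = [], 0
    for d in spec:
        if d == -2:
            # copy every input dim not yet consumed by entries to the left
            out.extend(in_shape[src_idx:])
            src_idx = len(in_shape)
        elif d == 0:
            out.append(in_shape[src_idx])
            src_idx += 1
        elif d > 0:
            out.append(d)
            src_idx += 1
        else:
            raise ValueError("this sketch only handles positive ints, 0 and -2")
    if math.prod(out) != math.prod(in_shape):
        raise ValueError("new shape must have the same size as the input")
    return tuple(out)


print(apply_minus_two((2, 3, 4), (-2,)))       # (2, 3, 4)
print(apply_minus_two((2, 3, 4), (2, -2)))     # (2, 3, 4)
print(apply_minus_two((2, 3, 4), (-2, 1, 1)))  # (2, 3, 4, 1, 1)
```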

-3

-3 merges two consecutive input dimensions into a single dimension whose size is the product of the two.

  • input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
  • input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
  • input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
  • input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)
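Extending the same cursor idea, a -3 reads two input dimensions at once and advances the cursor by two. The hypothetical `apply_minus_three` below also handles 0 and -2 so that all four examples above can be run; it is a sketch, not MXNet code.

```python
import math


def apply_minus_three(in_shape, spec):
    """Resolve positive ints, 0, -2 and -3 (merge two consecutive input dims)."""
    out, src_idx = [], 0
    for d in spec:
        if d == -3:
            if src_idx + 1 >= len(in_shape):
                raise ValueError("-3 needs two remaining input dimensions")
            out.append(in_shape[src_idx] * in_shape[src_idx + 1])
            src_idx += 2
        elif d == -2:
            out.extend(in_shape[src_idx:])
            src_idx = len(in_shape)
        elif d == 0:
            out.append(in_shape[src_idx])
            src_idx += 1
        elif d > 0:
            out.append(d)
            src_idx += 1
        else:
            raise ValueError("this sketch only handles positive ints, 0, -2 and -3")
    if math.prod(out) != math.prod(in_shape):
        raise ValueError("new shape must have the same size as the input")
    return tuple(out)


print(apply_minus_three((2, 3, 4), (-3, 4)))       # (6, 4)
print(apply_minus_three((2, 3, 4, 5), (-3, -3)))   # (6, 20)
print(apply_minus_three((2, 3, 4), (0, -3)))       # (2, 12)
print(apply_minus_three((2, 3, 4), (-3, -2)))      # (6, 4)
```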

-4

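Unlike -3, -4 splits one input dimension into two. -4 must be followed by two numbers giving the sizes of the split dimensions; one of them (but not both) may be -1, in which case it is inferred from the input dimension being split.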

  • input shape = (2,3,4), shape = (-4,1,2,-2), output shape =(1,2,3,4) # split 2 into 1x2, the remaining 3,4 are passed to -2
  • input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4) # split 3 into 1x3, the remaining 4 is passed to -2
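In a sketch, -4 is the one special value that also consumes the two entries after it in `shape`, so a plain `for` loop no longer fits and an index `i` is advanced manually. The function name `apply_minus_four` is hypothetical, and only positive ints, 0, -2 and -4 are handled.

```python
import math


def apply_minus_four(in_shape, spec):
    """Resolve positive ints, 0, -2 and -4 (split one input dim into two).

    The two values after -4 are the sizes of the split; one of them
    (but not both) may be -1.
    """
    out, src_idx, i = [], 0, 0
    while i < len(spec):
        d = spec[i]
        if d == -4:
            if i + 2 >= len(spec):
                raise ValueError("-4 must be followed by two values")
            d0, d1, d2 = in_shape[src_idx], spec[i + 1], spec[i + 2]
            if d1 == -1 and d2 == -1:
                raise ValueError("split dims cannot both be -1")
            if d1 == -1:
                d1 = d0 // d2
            if d2 == -1:
                d2 = d0 // d1
            if d1 * d2 != d0:
                raise ValueError("split dims do not multiply to the input dim")
            out += [d1, d2]
            src_idx += 1
            i += 3  # skip -4 and its two split values
            continue
        if d == -2:
            out.extend(in_shape[src_idx:])
            src_idx = len(in_shape)
        elif d == 0:
            out.append(in_shape[src_idx])
            src_idx += 1
        elif d > 0:
            out.append(d)
            src_idx += 1
        else:
            raise ValueError("this sketch only handles positive ints, 0, -2 and -4")
        i += 1
    if math.prod(out) != math.prod(in_shape):
        raise ValueError("new shape must have the same size as the input")
    return tuple(out)


print(apply_minus_four((2, 3, 4), (-4, 1, 2, -2)))      # (1, 2, 3, 4)
print(apply_minus_four((2, 3, 4), (2, -4, -1, 3, -2)))  # (2, 1, 3, 4)
```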

reverse

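If `reverse` is set to 1, the special values are matched from right to left instead of left to right.

  • input shape = (10,5,4), shape = (-1,0), output shape = (40,5) # left to right: 0 keeps the second input dimension 5, and -1 infers 200/5 = 40
  • input shape = (10,5,4), shape = (-1,0), reverse=1, output shape = (50,4) # right to left: 0 keeps the last input dimension 4, and -1 infers 200/4 = 50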
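Putting the rules together, the whole inference can be sketched as one pure-Python function. This is an illustrative re-implementation, not MXNet's source: the name `infer_reshape` is hypothetical, and one simplification is labeled in the code (combining `reverse` with -4 is rejected rather than modeled). Here reverse is implemented by reversing both the input shape and `shape`, resolving left to right, and reversing the result, which reproduces the (40,5) / (50,4) example above.

```python
import math


def infer_reshape(in_shape, spec, reverse=False):
    """Pure-Python sketch of MXNet-style reshape shape inference.

    Handles positive ints and 0, -1, -2, -3, -4. With reverse=True both
    shapes are reversed, resolved left to right, and the result is
    reversed back (this sketch does not allow -4 together with reverse).
    """
    dims, params = list(in_shape), list(spec)
    if reverse:
        if -4 in params:
            raise ValueError("this sketch does not combine reverse with -4")
        dims.reverse()
        params.reverse()
    out, src_idx, inf_idx, i = [], 0, None, 0
    while i < len(params):
        d = params[i]
        if d == 0:  # keep the input dim at this position
            out.append(dims[src_idx])
            src_idx += 1
        elif d == -1:  # placeholder, inferred after the loop
            if inf_idx is not None:
                raise ValueError("only one free-standing -1 is allowed")
            inf_idx = len(out)
            out.append(1)
            src_idx += 1
        elif d == -2:  # copy all remaining input dims
            out.extend(dims[src_idx:])
            src_idx = len(dims)
        elif d == -3:  # merge two consecutive input dims
            out.append(dims[src_idx] * dims[src_idx + 1])
            src_idx += 2
        elif d == -4:  # split one input dim into the next two values
            if i + 2 >= len(params):
                raise ValueError("-4 must be followed by two values")
            d0, d1, d2 = dims[src_idx], params[i + 1], params[i + 2]
            if d1 == -1 and d2 == -1:
                raise ValueError("split dims cannot both be -1")
            if d1 == -1:
                d1 = d0 // d2
            if d2 == -1:
                d2 = d0 // d1
            if d1 * d2 != d0:
                raise ValueError("split dims do not multiply to the input dim")
            out += [d1, d2]
            src_idx += 1
            i += 2  # the two split values are consumed here
        elif d > 0:  # explicit new dim
            out.append(d)
            src_idx += 1
        else:
            raise ValueError("invalid value %d in shape" % d)
        i += 1
    total = math.prod(in_shape)
    if inf_idx is not None:
        out[inf_idx] = total // math.prod(out)
    if math.prod(out) != total:
        raise ValueError("new shape must have the same size as the input")
    if reverse:
        out.reverse()
    return tuple(out)


print(infer_reshape((10, 5, 4), (-1, 0)))                # (40, 5)
print(infer_reshape((10, 5, 4), (-1, 0), reverse=True))  # (50, 4)
```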

An example: implementing GN (Group Normalization)

import mxnet as mx
import numpy as np

class GroupNorm(mx.gluon.HybridBlock):
    r"""Group Normalization

    refer to paper <Group Normalization>

    """
    def __init__(self,
                 in_channels,
                 groups=32,
                 gamma_initializer='ones',
                 beta_initializer='zeros',
                 eps=2e-5,
                 mirroring_level=0,
                 **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.groups = min(in_channels, groups)
        assert in_channels % self.groups == 0, "Channel number should be divisible by groups."
        # The original read these two settings from a project-specific
        # SpecialAttrScope helper (not part of MXNet); here they are plain
        # constructor arguments with the same defaults.
        self.mirroring_level = mirroring_level
        self.eps = eps
        self.use_fp16 = False
        with self.name_scope():
            self.gamma = self.params.get('gamma',
                                         grad_req='write',
                                         shape=(1, in_channels, 1, 1),
                                         init=gamma_initializer,
                                         allow_deferred_init=True,
                                         differentiable=True)
            self.beta = self.params.get('beta',
                                        grad_req='write',
                                        shape=(1, in_channels, 1, 1),
                                        init=beta_initializer,
                                        allow_deferred_init=True,
                                        differentiable=True)

    def cast(self, dtype):
        self.use_fp16 = False
        if np.dtype(dtype).name == 'float16':
            self.use_fp16 = True
            dtype = 'float32'
        super(GroupNorm, self).cast(dtype)

    def hybrid_forward(self, F, x, gamma, beta):
        _kwargs = {}
        if F is mx.symbol and self.mirroring_level >= 3:
            _kwargs['force_mirroring'] = 'True'

        if self.use_fp16:
            x = F.cast(data=x, dtype='float32')

        # (N, C, H, W) --> (N, G, C//G, H, W): -1 infers N, -4 splits C into G x C//G, -2 keeps H, W
        x = F.reshape(x, shape=(-1, -4, self.groups, -1, -2))

        # y = (x - mean) / sqrt(var + eps)
        mean = F.mean(x, axis=(2, 3, 4), keepdims=True, **_kwargs)
        y = F.broadcast_sub(x, mean, **_kwargs)
        var = F.mean(y**2, axis=(2, 3, 4), keepdims=True, **_kwargs)
        y = F.broadcast_div(y, F.sqrt(var + self.eps))

        # (N, G, C//G, H, W) --> (N, C, H, W): -1 infers N, -3 merges G and C//G back into C, -2 keeps H, W
        y = F.reshape(y, shape=(-1, -3, -2))

        y = F.broadcast_mul(y, gamma, **_kwargs)
        y = F.broadcast_add(y, beta, **_kwargs)

        if self.use_fp16:
            y = F.cast(data=y, dtype='float16')

        return y
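To check what the two reshapes in hybrid_forward compute, here is a NumPy reference of the same forward pass. It is a sketch: `group_norm_reference` is a hypothetical name, and the fp16 and mirroring branches are left out. NumPy's reshape has no -2/-3/-4, so the two MXNet shapes are written out with explicit sizes: (-1, -4, groups, -1, -2) becomes (N, G, C//G, H, W), and (-1, -3, -2) becomes (N, C, H, W).

```python
import numpy as np


def group_norm_reference(x, gamma, beta, groups, eps=2e-5):
    """NumPy sketch of GroupNorm.hybrid_forward for an (N, C, H, W) array."""
    n, c, h, w = x.shape
    # MXNet shape=(-1, -4, groups, -1, -2): (N, C, H, W) -> (N, G, C//G, H, W)
    xg = x.reshape(n, groups, c // groups, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    y = xg - mean
    var = (y ** 2).mean(axis=(2, 3, 4), keepdims=True)
    y = y / np.sqrt(var + eps)
    # MXNet shape=(-1, -3, -2): (N, G, C//G, H, W) -> (N, C, H, W)
    y = y.reshape(n, c, h, w)
    # gamma and beta have shape (1, C, 1, 1) and broadcast over N, H, W
    return y * gamma + beta


x = np.random.default_rng(0).standard_normal((2, 8, 3, 3))
gamma = np.ones((1, 8, 1, 1))
beta = np.zeros((1, 8, 1, 1))
y = group_norm_reference(x, gamma, beta, groups=4)
print(y.shape)  # (2, 8, 3, 3)
```

Each group here is a run of C//G consecutive channels (channels 0 and 1, 2 and 3, and so on for C=8, G=4), which is the grouping produced by splitting the channel axis with -4.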
