Advanced mxnet - source code analysis of mxnet.io.DataDesc

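The full name should be Data Descriptor.

It is a subclass of a namedtuple.

At its core it is still a tuple. The only difference is that the element at each position has been given a name, so elements can be accessed by name as well as by index, much like the keys of an ordered dictionary (OrderedDict).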
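
To make the namedtuple behaviour described above concrete, here is a minimal sketch using only the standard library. `Point` and its fields are hypothetical names chosen for illustration and have nothing to do with mxnet.

```python
from collections import namedtuple

# Hypothetical example type, not part of mxnet.
Point = namedtuple('Point', ['x', 'y'])

p = Point(x=3, y=4)

# Fields can be read by name or by position.
assert p.x == 3 and p[0] == 3

# It is still an ordinary tuple: it unpacks and compares like one.
x, y = p
assert isinstance(p, tuple) and p == (3, 4)

# _asdict() gives the dictionary-like view, keyed by field name in field order.
fields = p._asdict()
print(list(fields.items()))  # [('x', 3), ('y', 4)]
```

Because the fields are tuple slots, they are read-only: assigning to `p.x` raises `AttributeError`.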

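mxnet.io.DataDesc describes the data with two named tuple fields (dtype and layout are set as ordinary attributes in __new__ and are not stored in the tuple itself):

  • name: the name of the data, a string (str)
  • shape: the shape of the data, a tuple of int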
class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
    """DataDesc is used to store name, shape, type and layout
    information of the data or the label.

    The `layout` describes how the axes in `shape` should be interpreted,
    for example for image data setting `layout=NCHW` indicates
    that the first axis is number of examples in the batch(N),
    C is number of channels, H is the height and W is the width of the image.

    For sequential data, by default `layout` is set to ``NTC``, where
    N is number of examples in the batch, T the temporal axis representing time
    and C is the number of channels.

    Parameters
    ----------
    cls : DataDesc
         The class.
    name : str
         Data name.
    shape : tuple of int
         Data shape.
    dtype : np.dtype, optional
         Data type.
    layout : str, optional
         Data layout.
    """
    def __new__(cls, name, shape, dtype=mx_real_t, layout='NCHW'): # pylint: disable=super-on-old-class
        ret = super(cls, DataDesc).__new__(cls, name, shape)
        ret.dtype = dtype
        ret.layout = layout
        return ret

    def __repr__(self):
        return "DataDesc[%s,%s,%s,%s]" % (self.name, self.shape, self.dtype,
                                          self.layout)

    @staticmethod
    def get_batch_axis(layout):
        """Get the dimension that corresponds to the batch size.

        When data parallelism is used, the data will be automatically split and
        concatenated along the batch-size dimension. Axis can be -1, which means
        the whole array will be copied for each data-parallelism device.

        Parameters
        ----------
        layout : str
            layout string. For example, "NCHW".

        Returns
        -------
        int
            An axis indicating the batch_size dimension.
        """
        if layout is None:
            return 0
        return layout.find('N')

    @staticmethod
    def get_list(shapes, types):
        """Get DataDesc list from attribute lists.

        Parameters
        ----------
        shapes : a tuple of (name, shape)
        types : a tuple of  (name, type)
        """
        if types is not None:
            type_dict = dict(types)
            return [DataDesc(x[0], x[1], type_dict[x[0]]) for x in shapes]
        else:
            return [DataDesc(x[0], x[1]) for x in shapes]
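
The behaviour of the class above can be explored without installing mxnet. The following is a rough sketch, not the library itself: `DataDescSketch` and `DEFAULT_DTYPE` are hypothetical stand-ins, and the string `'float32'` is an assumption replacing `mx_real_t` (which, as far as I know, is NumPy's float32). The method bodies mirror the source quoted above.

```python
from collections import namedtuple

# Assumption: a plain string stands in for mxnet's mx_real_t.
DEFAULT_DTYPE = 'float32'


class DataDescSketch(namedtuple('DataDescSketch', ['name', 'shape'])):
    """Hypothetical stand-in for mxnet.io.DataDesc, for illustration only."""

    def __new__(cls, name, shape, dtype=DEFAULT_DTYPE, layout='NCHW'):
        # Only name and shape go into the tuple.
        ret = super().__new__(cls, name, shape)
        # dtype and layout are attached as ordinary instance attributes.
        ret.dtype = dtype
        ret.layout = layout
        return ret

    @staticmethod
    def get_batch_axis(layout):
        # Position of 'N' in the layout string; -1 if there is no batch axis.
        if layout is None:
            return 0
        return layout.find('N')

    @staticmethod
    def get_list(shapes, types):
        # shapes: iterable of (name, shape); types: iterable of (name, type) or None.
        if types is not None:
            type_dict = dict(types)
            return [DataDescSketch(x[0], x[1], type_dict[x[0]]) for x in shapes]
        return [DataDescSketch(x[0], x[1]) for x in shapes]


desc = DataDescSketch('data', (32, 3, 224, 224))

# The tuple has exactly two elements, so it unpacks into name and shape.
name, shape = desc
assert len(desc) == 2

# Equality is plain tuple equality, so dtype and layout are not compared.
other = DataDescSketch('data', (32, 3, 224, 224), dtype='float16', layout='NHWC')
assert desc == other and desc.layout != other.layout

# Batch axis for a few layout strings.
axes = [DataDescSketch.get_batch_axis(s) for s in ('NCHW', 'TNC', 'CHW', None)]
assert axes == [0, 1, -1, 0]

# Build descriptors from parallel (name, shape) and (name, type) lists.
descs = DataDescSketch.get_list(
    [('data', (32, 100)), ('label', (32,))],
    [('data', 'float16'), ('label', 'int32')],
)
assert [d.dtype for d in descs] == ['float16', 'int32']
```

Two details stand out in this sketch. Since `dtype` and `layout` live outside the tuple, two descriptors with different layouts still compare equal. And `get_list` never passes a layout, so every descriptor it builds keeps the default `'NCHW'`, whatever the rank of its shape.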

 
