MXNet Advanced: mx.io.NDArrayIter Source Code Analysis

Introduction

This post looks at how MXNet supplies data through an iterator during training and validation.

A few questions to answer:

  • How is the iterator constructed?
  • During training or testing, a batch is fetched with data_batch = next(iterator). What exactly is the returned data_batch, and how is it produced?
  • How are provide_data and provide_label designed, and how are they used to initialize a module or executor?

We analyze MXNet's built-in mx.io.NDArrayIter to see how an NDArray is turned into an iterator that can be fed to module.fit().

The test code below trains an MLP on MNIST:

'''
Loading Data
'''
import mxnet as mx
from collections import OrderedDict
from mxnet.ndarray import array
mnist = mx.test_utils.get_mnist()  # dict
# 'train_data'  ndarray, shape (60000, 1, 28, 28)
# 'train_label' ndarray, shape (60000,)
# 'test_data'   ndarray, shape (10000, 1, 28, 28)
# 'test_label'  ndarray, shape (10000,)
# Fix the seed
mx.random.seed(42)

# Set the compute context: GPU if available, otherwise CPU
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()


batch_size = 100
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
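
Before digging into the source, a quick look at what the iterator exposes frames the questions above. A minimal sketch (the printed reprs are illustrative and may differ slightly across MXNet versions):

print(train_iter.provide_data)   # e.g. [DataDesc[data,(100, 1, 28, 28),<class 'numpy.float32'>,NCHW]]
print(train_iter.provide_label)  # e.g. [DataDesc[softmax_label,(100,),<class 'numpy.float32'>,NCHW]]

batch = next(train_iter)         # an mxnet.io.DataBatch
print(type(batch.data[0]), batch.data[0].shape)  # an mx.NDArray of shape (100, 1, 28, 28)
train_iter.reset()               # rewind the cursor before handing the iterator to fit()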

'''
Training
'''

'''
The name 'data' here must not be changed: it matches the default value of
mx.io.NDArrayIter's data_name parameter, which is 'data' (this will become
clear below). You can also rename it and inspect the resulting error.
'''
data = mx.sym.var('data')
# Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)
data = mx.sym.flatten(data=data)


# The first fully-connected layer and the corresponding activation function
fc1  = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")

# The second fully-connected layer and the corresponding activation function
fc2  = mx.sym.FullyConnected(data=act1, num_hidden = 64)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
# MNIST has 10 classes
fc3  = mx.sym.FullyConnected(data=act2, num_hidden=10)
# Softmax with cross entropy loss
mlp  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

import logging
logging.getLogger().setLevel(logging.DEBUG)  # logging to stdout
# create a trainable module on compute context
mlp_model = mx.mod.Module(symbol=mlp, context=ctx)
mlp_model.fit(train_iter,  # train data
              eval_data=val_iter,  # validation data
              optimizer='sgd',  # use SGD to train
              optimizer_params={'learning_rate':0.1},  # use fixed learning rate
              eval_metric='acc',  # report accuracy during training
              batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches
              num_epoch=10)  # train for at most 10 dataset passes

Let's look at mx.io.NDArrayIter.__init__():

 def __init__(self, data, label=None, batch_size=1, shuffle=False,
                 last_batch_handle='pad', data_name='data',
                 label_name='softmax_label'):
        super(NDArrayIter, self).__init__(batch_size)
        '''Normalize the inputs to list(tuple(key, val), tuple(key, val), ...)'''
        '''Key point!!! These keys correspond to the symbol names in the executor'''
        self.data = _init_data(data, allow_empty=False, default_name=data_name)
        self.label = _init_data(label, allow_empty=True, default_name=label_name)

        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
                (last_batch_handle != 'discard')):
            raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                      " with `last_batch_handle` set to `discard`.")
        '''
        self.idx is an arange-generated index of size self.data[0][1].shape[0].
        The use of shape[0] shows the input ndarray must be laid out as
        num_data * data_instance.
        If shuffle is requested, permute the data.
        '''
        if shuffle:
            tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
            self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
            self.data = _shuffle(self.data, self.idx)
            self.label = _shuffle(self.label, self.idx)
        else:
            self.idx = np.arange(self.data[0][1].shape[0])
        '''With 'discard', trim the input down to a multiple of batch_size'''
        # batching
        if last_batch_handle == 'discard':
            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
            self.idx = self.idx[:new_n]
        '''Join data and label into one list = [data_0_ndarray, data_1_ndarray, ..., label_0_ndarray, ...]'''
        self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
        '''Total number of ndarrays across inputs and labels (2 for MNIST)'''
        self.num_source = len(self.data_list)
        '''Number of samples (60000 for the MNIST training set)'''
        self.num_data = self.idx.shape[0]
        '''batch_size must not exceed the total number of samples'''
        assert self.num_data >= batch_size, \
            "batch_size needs to be smaller than data size."
        '''Define a cursor'''
        self.cursor = -batch_size
        self.batch_size = batch_size
        '''How to handle a final incomplete batch: 'pad' or 'discard' '''
        self.last_batch_handle = last_batch_handle
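
A small experiment confirms that shuffle=True permutes data and label with the same index; a sketch (the actual permutation is random):

import numpy as np
import mxnet as mx

x = np.arange(6).reshape(6, 1)   # sample i holds the value i
y = np.arange(6)                 # label i is also i
it = mx.io.NDArrayIter(x, y, batch_size=3, shuffle=True)
b = next(it)
print(b.data[0].asnumpy().ravel())   # e.g. [4. 0. 3.]
print(b.label[0].asnumpy().ravel())  # the same permutation: [4. 0. 3.]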

_init_data normalizes the input format, because the initializer accepts several input types (numpy.ndarray / mxnet.ndarray / h5py.Dataset), either as a single object of one of these types or as a list of them (see the quick check after the source below):

  • a single mxnet.ndarray input
  • comes out as list[tuple(str{'data'}, mxnet.ndarray)], i.e. the key is just default_name
  • a list input [mxnet.ndarray, mxnet.ndarray]
  • comes out as list[tuple(str{'_0_data'}, mxnet.ndarray), tuple(str{'_1_data'}, mxnet.ndarray)]
def _init_data(data, allow_empty, default_name):
    """Convert data into canonical form."""
    assert (data is not None) or allow_empty
    if data is None:
        data = []
    '''If the input is not a list, wrap it in a single-element list'''
    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
                  if h5py else (np.ndarray, NDArray)):
        data = [data]
    # type(data) = list
    '''Next, convert the list into an OrderedDict'''
    if isinstance(data, list):
        if not allow_empty:
            assert(len(data) > 0)
        if len(data) == 1:
            '''A single ndarray in the list: the dict gets one entry whose
            key is default_name and whose value is the input data'''
            data = OrderedDict([(default_name, data[0])])  # pylint: disable=redefined-variable-type
        else:
            '''Multiple ndarrays: the dict gets several entries with keys
            named '_%d_%s' % (i, default_name),
            e.g. {('_0_data', ndarray), ('_1_data', ndarray), ...}'''
            data = OrderedDict( # pylint: disable=redefined-variable-type
                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
    if not isinstance(data, dict):
        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
                "a list of them or dict with them as values")
    '''Convert any non-mxnet.ndarray values to mxnet.ndarray'''
    for k, v in data.items():
        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
            try:
                data[k] = array(v)
            except:
                raise TypeError(("Invalid type '%s' for %s, "  % (type(v), k)) + \
                                "should be NDArray, numpy.ndarray or h5py.Dataset")
    '''Turn the dict into a list: each (key, val) entry becomes a tuple(key, val)'''
    return list(sorted(data.items()))
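
The naming rule is easy to verify through the public API; a sketch assuming the default data_name='data':

import numpy as np
import mxnet as mx

single = mx.io.NDArrayIter(np.zeros((10, 4)), batch_size=5)
print([d.name for d in single.provide_data])   # ['data']

multi = mx.io.NDArrayIter([np.zeros((10, 4)), np.zeros((10, 2))], batch_size=5)
print([d.name for d in multi.provide_data])    # ['_0_data', '_1_data']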

After initialization, mod.fit() uses several member APIs of this iterator (a small sketch follows the list):

  • mx.io.NDArrayIter.provide_data  # a property, used to initialize the module or executor
  • mx.io.NDArrayIter.provide_label # likewise a property, used together with the one above
  • iter(mx.io.NDArrayIter)  # obtain an iterator used to fetch one data batch per training step
  • mx.io.NDArrayIter.reset()  # called once at the end of each epoch during training
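
The iterator protocol and reset() can be observed in isolation; a minimal sketch:

import numpy as np
import mxnet as mx

it = mx.io.NDArrayIter(np.zeros((10, 2)), batch_size=5)
print(sum(1 for _ in it))  # 2 batches
print(sum(1 for _ in it))  # 0: the cursor is past the end
it.reset()                 # cursor back to -batch_size
print(sum(1 for _ in it))  # 2 again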
The provide_data property:

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator."""
        # DataDesc is a namedtuple. Building one uses two pieces of information.
        # self.data has the structure list[tuple1(name1, val1), tuple2(name2, val2), ...],
        # where each name is the name of a symbol input argument (assigned when the
        # iterator was created above) and each val is that argument's n*v data matrix.
        # Each DataDesc is constructed from str(name) and tuple(batch_size, v).
        return [
            DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
            for k, v in self.data
        ]
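
DataDesc itself can be poked at directly; a sketch (in the 1.x line it is importable from mxnet.io):

import numpy as np
from mxnet.io import DataDesc

desc = DataDesc('data', (100, 1, 28, 28), np.float32)
print(desc.name, desc.shape, desc.dtype)  # data (100, 1, 28, 28) <class 'numpy.float32'>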

How Module / Executor initialization uses mx.io.NDArrayIter

Module.fit() calls Module.bind(), which uses mx.io.NDArrayIter.provide_data and mx.io.NDArrayIter.provide_label:

self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)

'''Parse the data descriptors; the comments show what each argument resolves to in this example'''
    self._data_shapes, self._label_shapes = _parse_data_desc(
            self.data_names,                        # ['data']
            self.label_names,                       # ['softmax_label']
            data_shapes=train_data.provide_data,    # the iterator's DataDesc list
            label_shapes=train_data.provide_label)
'''This data-descriptor parser does two things:
        converts the data attributes into DataDesc format
        checks that the names in the input data attribute list match the names of the network's input symbols
'''
def _parse_data_desc(data_names, label_names, data_shapes, label_shapes):
    """parse data_attrs into DataDesc format and check that names match"""
    data_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in data_shapes]
    _check_names_match(data_names, data_shapes, 'data', True)
    if label_shapes is not None:
        label_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in label_shapes]
        _check_names_match(label_names, label_shapes, 'label', False)
    else:
        _check_names_match(label_names, [], 'label', False)
    return data_shapes, label_shapes
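
To see the conversion concretely, the helper can be exercised on plain (name, shape) tuples; a sketch assuming the internal import path mxnet.module.base_module (internal helpers may move between versions):

from mxnet.module.base_module import _parse_data_desc

ds, ls = _parse_data_desc(
    data_names=['data'], label_names=['softmax_label'],
    data_shapes=[('data', (100, 1, 28, 28))],
    label_shapes=[('softmax_label', (100,))])
print(ds)  # [DataDesc[data,(100, 1, 28, 28),<class 'numpy.float32'>,NCHW]]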

'''The resulting data descriptors are then used to initialize the executor-group class'''
        self._exec_group = DataParallelExecutorGroup(self._symbol, self._context,
                                                     self._work_load_list, self._data_shapes,
                                                     self._label_shapes, self._param_names,
                                                     for_training, inputs_need_grad,
                                                     shared_group, logger=self.logger,
                                                     fixed_param_names=self._fixed_param_names,
                                                     grad_req=grad_req, group2ctxs=self._group2ctxs,
                                                     state_names=self._state_names)

    '''The initializer uses these data descriptors in its final line'''  # shared_group = None
        self.bind_exec(data_shapes, label_shapes, shared_group)
    '''Digging further: the machinery above exists for multi-GPU data parallelism. Here the slice of the batch owned by each GPU is handed to its own Executor, and the Executors are collected'''
        self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i,
                                                      shared_group))
    '''Finally, the endpoint of Module.bind(): obtain an Executor through simple_bind'''
    def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
        """Internal utility function to bind the i-th executor.
        This function utilizes simple_bind python interface.
        """
        shared_exec = None if shared_group is None else shared_group.execs[i]
        context = self.contexts[i]
        shared_data_arrays = self.shared_data_arrays[i]

        input_shapes = dict(data_shapes)
        if label_shapes is not None:
            input_shapes.update(dict(label_shapes))
        '''Build a dict from the input data descriptors; it is used to initialize the executor'''
        input_types = {x.name: x.dtype for x in data_shapes}
        if label_shapes is not None:
            input_types.update({x.name: x.dtype for x in label_shapes})

        group2ctx = self.group2ctxs[i]
        '''simple_bind
        deserves its own blog post later.
        For now my understanding is that it
        builds the network's static graph according to the input data descriptors
        and allocates the corresponding memory on each context'''
        executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
                                           type_dict=input_types, shared_arg_names=self.param_names,
                                           shared_exec=shared_exec, group2ctx=group2ctx,
                                           shared_buffer=shared_data_arrays, **input_shapes)
        self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
        return executor
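
simple_bind can also be called directly on the mlp symbol from the example above; a sketch (softmax_label's shape is inferred from the data shape):

exe = mlp.simple_bind(ctx=mx.cpu(), data=(100, 1, 28, 28))
print(type(exe))  # <class 'mxnet.executor.Executor'>
print({k: v.shape for k, v in zip(mlp.list_arguments(), exe.arg_arrays)})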
    

How Module / Executor training uses mx.io.NDArrayIter

Here the focus is on how fit() pulls data from mx.io.NDArrayIter for training within each epoch.

for epoch in range(begin_epoch, num_epoch):
    tic = time.time()
    eval_metric.reset()
    nbatch = 0
    '''iter() turns a non-iterator iterable (such as a list) into an iterator.
    In this blog it changes nothing about train_data,
    because the train_data passed in is already an iterator.

    '''
    data_iter = iter(train_data)
    '''
    print(type(data_iter), type(train_iter))
    <class 'mxnet.io.NDArrayIter'> <class 'mxnet.io.NDArrayIter'>
    '''
    '''Initialize a flag used to detect when the iterator reaches its end'''
    end_of_batch = False
    '''Fetch one batch of data, <class 'mxnet.io.DataBatch'>.
    DataBatch has these two important member variables:
        data <class 'list<class mx.NDArray>'>
        label <class 'list<class mx.NDArray>'>
    '''
    next_data_batch = next(data_iter)
    '''
    def next(self):
        # call iter_next() to check whether the iterator has reached the end
        if self.iter_next():
            return DataBatch(data=self.getdata(), label=self.getlabel(), \
                    pad=self.getpad(), index=None)
        else:
            raise StopIteration

    # self.cursor is initialized to -self.batch_size so that it points at the
    # data to fetch: on the first next() call the cursor becomes 0.
    # Every call to next() then advances it by self.batch_size.
    # Returns True while the cursor is still within the data, False otherwise.
    def iter_next(self):
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    '''
    
    while not end_of_batch:
        data_batch = next_data_batch
        if monitor is not None:
            monitor.tic()
        self.forward_backward(data_batch)
        self.update()
        '''After the parameter update completes, fetch the next batch'''
        try:
            # pre fetch next batch
            next_data_batch = next(data_iter)
            self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
        '''this except pairs with the raise StopIteration above'''
        except StopIteration:
            end_of_batch = True
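
Stripped of metrics and callbacks, the data path of fit() boils down to a loop like this; a sketch reusing mlp, ctx and train_iter from the example above:

mod = mx.mod.Module(symbol=mlp, context=ctx)
mod.bind(data_shapes=train_iter.provide_data,
         label_shapes=train_iter.provide_label)
mod.init_params()
mod.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': 0.1})

for epoch in range(2):
    train_iter.reset()            # cursor back to -batch_size
    for batch in train_iter:      # each batch is a DataBatch
        mod.forward_backward(batch)
        mod.update()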

The key part: the getdata() that next() calls, which delegates to _getdata():

    def _getdata(self, data_source):
        """Load data from underlying arrays, internal use only."""
        assert(self.cursor < self.num_data), "DataIter needs reset."
        '''Check whether the data remaining in the iterator fills a whole batch'''
        if self.cursor + self.batch_size <= self.num_data:
            return [
                # np.ndarray or NDArray case
                '''data_source = self.data, <class 'list<tuple<str_name, ndarray_data>>'>;
                    take the [cursor : cursor + batch_size] slice from every member of the self.data list
                '''
                x[1][self.cursor:self.cursor + self.batch_size]
                if isinstance(x[1], (np.ndarray, NDArray)) else
                # h5py (only supports indices in increasing order)
                array(x[1][sorted(self.idx[
                    self.cursor:self.cursor + self.batch_size])][[
                        list(self.idx[self.cursor:
                                      self.cursor + self.batch_size]).index(i)
                        for i in sorted(self.idx[
                            self.cursor:self.cursor + self.batch_size])
                    ]]) for x in data_source
            ]
        else:
            pad = self.batch_size - self.num_data + self.cursor
            return [
                # np.ndarray or NDArray case
                concatenate([x[1][self.cursor:], x[1][:pad]])
                if isinstance(x[1], (np.ndarray, NDArray)) else
                # h5py (only supports indices in increasing order)
                concatenate([
                    array(x[1][sorted(self.idx[self.cursor:])][[
                        list(self.idx[self.cursor:]).index(i)
                        for i in sorted(self.idx[self.cursor:])
                    ]]),
                    array(x[1][sorted(self.idx[:pad])][[
                        list(self.idx[:pad]).index(i)
                        for i in sorted(self.idx[:pad])
                    ]])
                ]) for x in data_source
            ]
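
The 'pad' branch wraps around to the front of the data, which is easy to observe; a sketch:

import numpy as np
import mxnet as mx

it = mx.io.NDArrayIter(np.arange(10).reshape(10, 1), batch_size=4,
                       last_batch_handle='pad')
for b in it:
    print(b.data[0].asnumpy().ravel(), 'pad =', b.pad)
# [0. 1. 2. 3.] pad = 0
# [4. 5. 6. 7.] pad = 0
# [8. 9. 0. 1.] pad = 2   <- the last batch wraps to the front; pad = 2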

Summary

  • The data names given when constructing a DataIter must correspond to the network's input symbol names
  • The DataIter is used both for network initialization and for training and testing