介紹
想知道mxnet在訓練過程或者驗證過程中,如何通過iterator提供數據
幾個問題:
- 如何構造iterator?
- 訓練或者測試時從iterator中獲取數據,data_batch = next(iterator).getdata(),輸出的data_batch是什麼?又是怎麼獲得的?
- provide_data,provide_label如何設計以及如何應用於module或executor初始化?
分析mxnet自帶的mx.io.NDArrayIter,看如何把一個NDArray轉化爲一個可以用於module.fit() 的 iterator
用於測試的代碼,使用一個MLP學習mnist的例子
'''
Loading Data
'''
import mxnet as mx
from collections import OrderedDict
from mxnet.ndarray import array
# Load MNIST as a dict of numpy arrays; the keys/shapes are listed below.
mnist = mx.test_utils.get_mnist()# dict
#'train_data' ndarray ,shape<class 'tuple'> (60000,1,28,28)
#'train_label' ndarray ,shape<class 'tuple'> (60000,)
#'test_data' ndarray ,shape<class 'tuple'> (10000,1,28,28)
#'test_label' ndarray ,shape<class 'tuple'> (10000,)
# Fix the seed
mx.random.seed(42)
# Set the compute context, GPU is available otherwise CPU
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
batch_size = 100
# Wrap the raw arrays in NDArrayIter; shuffle only the training set.
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
'''
Training
'''
'''
這裏的名字'data'不能改,對應於mx.io.NDArrayIter的data_name參數默認值就是'data',往後看就明白了
也可以改着看看bug信息
'''
# Input placeholder; the name 'data' must match NDArrayIter's data_name
# (default 'data'), otherwise binding fails with a name mismatch.
data = mx.sym.var('data')
# Flatten the 4-D (batch, 1, 28, 28) images into 2-D (batch, 784) rows.
data = mx.sym.flatten(data=data)
# Two relu-activated fully-connected hidden layers (128 and 64 units).
net = mx.sym.Activation(data=mx.sym.FullyConnected(data=data, num_hidden=128), act_type="relu")
net = mx.sym.Activation(data=mx.sym.FullyConnected(data=net, num_hidden=64), act_type="relu")
# Output layer: MNIST has 10 classes; softmax + cross-entropy loss.
net = mx.sym.FullyConnected(data=net, num_hidden=10)
mlp = mx.sym.SoftmaxOutput(data=net, name='softmax')

import logging
logging.getLogger().setLevel(logging.DEBUG) # logging to stdout

# Build a trainable module on the chosen context and run training.
mlp_model = mx.mod.Module(symbol=mlp, context=ctx)
mlp_model.fit(train_iter,                                                # train data
              eval_data=val_iter,                                        # validation data
              optimizer='sgd',                                           # use SGD to train
              optimizer_params={'learning_rate':0.1},                    # use fixed learning rate
              eval_metric='acc',                                         # report accuracy during training
              batch_end_callback=mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches
              num_epoch=10)                                              # train for at most 10 dataset passes
看 mx.io.NDArrayIter.__init__()
def __init__(self, data, label=None, batch_size=1, shuffle=False,
             last_batch_handle='pad', data_name='data',
             label_name='softmax_label'):
    """Build an NDArrayIter over in-memory data/label arrays.

    ``data``/``label`` may each be a single array, a list of arrays, or a
    dict of name->array; ``_init_data`` normalises them to
    ``list[(name, ndarray)]``.  The names are what the bound executor's
    input symbols are matched against.
    """
    super(NDArrayIter, self).__init__(batch_size)
    # Canonicalise inputs to list[(key, ndarray)]; the keys correspond to
    # the symbols in the executor this iterator feeds.
    self.data = _init_data(data, allow_empty=False, default_name=data_name)
    self.label = _init_data(label, allow_empty=True, default_name=label_name)
    # Sparse (CSR) inputs cannot be padded/rolled, so they are only
    # supported when the trailing short batch is discarded.
    if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
            (last_batch_handle != 'discard')):
        raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                  " with `last_batch_handle` set to `discard`.")
    # self.idx enumerates samples along axis 0 (size data[0][1].shape[0]),
    # so every input array must be laid out as (num_samples, ...).
    # When shuffle is requested, permute the index AND the arrays.
    if shuffle:
        tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
        self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
        self.data = _shuffle(self.data, self.idx)
        self.label = _shuffle(self.label, self.idx)
    else:
        self.idx = np.arange(self.data[0][1].shape[0])
    # 'discard': trim the index so the sample count is an exact multiple
    # of batch_size (the tail samples are dropped).
    # batching
    if last_batch_handle == 'discard':
        new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
        self.idx = self.idx[:new_n]
    # Flatten data and label into one list of ndarrays:
    # [data_0, data_1, ..., label_0, ...].
    self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
    # Total number of input/output arrays (2 for MNIST: data + label).
    self.num_source = len(self.data_list)
    # Number of samples (60000 for the MNIST training set).
    self.num_data = self.idx.shape[0]
    # batch_size may not exceed the total sample count.
    assert self.num_data >= batch_size, \
        "batch_size needs to be smaller than data size."
    # Cursor into the index: next() advances it by batch_size first, so
    # starting at -batch_size makes the first batch begin at sample 0.
    self.cursor = -batch_size
    self.batch_size = batch_size
    # How to handle a final short batch: 'pad' or 'discard'.
    self.last_batch_handle = last_batch_handle
_init_data 的作用是把輸入的data統一格式,因爲這個初始化data輸入支持多種類型numpy.ndarray/mxnet.ndarray/h5py.Dataset輸入,可以是單個的這些類型數據,也可能是他們的list輸入
- 比如輸入一個mxnet.ndarray
- 輸出的格式爲list[tuple(str{'data'},mxnet.ndarray)](單個輸入直接用default_name命名,不加'_0_'前綴)
- 如果輸入一個list:[mxnet.ndarray,mxnet.ndarray]
- 輸出格式爲list[tuple(str{'_0_data'},mxnet.ndarray),tuple(str{'_1_data'},mxnet.ndarray)]
def _init_data(data, allow_empty, default_name):
    """Convert data into canonical form: list[(name, NDArray), ...].

    Accepts a single numpy.ndarray / mxnet NDArray / h5py.Dataset, a list
    of them, or a dict mapping names to them.  A single array is named
    ``default_name``; a list of k arrays is named ``'_%d_%s' % (i,
    default_name)`` (e.g. '_0_data', '_1_data', ...).  Non-NDArray values
    are converted to mxnet NDArrays.  Returns the (name, array) pairs
    sorted by name.
    """
    assert (data is not None) or allow_empty
    if data is None:
        data = []
    # Normalise a bare array to a one-element list.
    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
                  if h5py else (np.ndarray, NDArray)):
        data = [data]
    # Normalise a list to an OrderedDict keyed by generated names.
    if isinstance(data, list):
        if not allow_empty:
            assert(len(data) > 0)
        if len(data) == 1:
            # A single input keeps the plain default name so that it
            # matches the network's input symbol (e.g. 'data').
            # BUGFIX: this assignment had been commented out, so a single
            # array fell through to the TypeError below.
            data = OrderedDict([(default_name, data[0])])  # pylint: disable=redefined-variable-type
        else:
            # Multiple inputs get positional names: '_0_data', '_1_data', ...
            data = OrderedDict(  # pylint: disable=redefined-variable-type
                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
    if not isinstance(data, dict):
        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
                        "a list of them or dict with them as values")
    # Convert any non-mxnet values (e.g. numpy arrays) to mxnet NDArrays.
    for k, v in data.items():
        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
            try:
                data[k] = array(v)
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt /
                # SystemExit are not swallowed.
                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
                                "should be NDArray, numpy.ndarray or h5py.Dataset")
    # Emit deterministic, name-sorted (name, array) pairs.
    return list(sorted(data.items()))
初始化完之後再mod.fit()當中使用到該iter的幾個成員API:
- mx.io.NDArrayIter.provide_data() # 用於module 或者 executor初始化
- mx.io.NDArrayIter.provide_label() # 用於module 或者 executor初始化和上面同步
- iter(mx.io.NDArrayIter) #得到一個迭代器,用於每次訓練獲取數據batch
- mx.io.NDArrayIter.reset() #訓練時每個epoch 結束時 reset一次
@property
def provide_data(self):
    """The name and shape of data provided by this iterator.

    ``self.data`` is ``list[(name, ndarray)]`` (see ``_init_data``): each
    name matches an input symbol of the network and each ndarray is laid
    out as (num_samples, ...).  Each entry is reported as a
    ``DataDesc(name, (batch_size, *sample_shape), dtype)`` — i.e. the
    per-batch shape that Module/Executor binding should use.

    BUGFIX: the explanatory text that had been pasted as a bare string
    inside the list literal (a syntax error) now lives in this docstring.
    """
    return [
        DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
        for k, v in self.data
    ]
Module / Executor 初始化時用到mx.io.NDArrayIter
mxnet.Module.fit().Module.bind()使用到了 mx.io.NDArrayIter.provide_data() 和mx.io.NDArrayIter.provide_label()
self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
for_training=True, force_rebind=force_rebind)
'''解析數據描述子'''
self._data_shapes, self._label_shapes = _parse_data_desc(
    self.data_names,    # e.g. ['data']
    self.label_names,   # e.g. ['softmax_label']
    data_shapes=train_data.provide_data,
    label_shapes=train_data.provide_label)
'''這個數據解析器幹兩件事:
把data attributes 轉成DataDesc 格式
檢查輸入的'數據屬性表中的名字'是否和'網絡輸入symbol的名字'匹配
'''
def _parse_data_desc(data_names, label_names, data_shapes, label_shapes):
    """parse data_attrs into DataDesc format and check that names match"""
    # Coerce every entry that is not already a DataDesc into one.
    def _to_descs(shapes):
        return [s if isinstance(s, DataDesc) else DataDesc(*s) for s in shapes]

    data_shapes = _to_descs(data_shapes)
    _check_names_match(data_names, data_shapes, 'data', True)
    if label_shapes is None:
        # No labels supplied: still verify the expected label names
        # against an empty descriptor list (non-strict check).
        _check_names_match(label_names, [], 'label', False)
    else:
        label_shapes = _to_descs(label_shapes)
        _check_names_match(label_names, label_shapes, 'label', False)
    return data_shapes, label_shapes
'''然後得到的數據描述子用於Executor group類的初始化'''
self._exec_group = DataParallelExecutorGroup(self._symbol, self._context,
self._work_load_list, self._data_shapes,
self._label_shapes, self._param_names,
for_training, inputs_need_grad,
shared_group, logger=self.logger,
fixed_param_names=self._fixed_param_names,
grad_req=grad_req, group2ctxs=self._group2ctxs,
state_names=self._state_names)
'''初始化裏在最後一行代碼用到這些數據描述子'''#shared_group =None
self.bind_exec(data_shapes, label_shapes, shared_group)
'''繼續搜,上面的處理是爲了多GPU並行處理而設定的,在這裏把每一個GPU負責的batch,分給每一個Executor,並把這些Executor收集起來'''
self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i,
shared_group))
'''繼續 這裏是Module.bind()的終點,通過simple_bind得到一個Executor'''
def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
    """Internal utility function to bind the i-th executor.
    This function utilizes simple_bind python interface.
    """
    # Reuse the i-th executor's resources when a shared group is given.
    shared_exec = None if shared_group is None else shared_group.execs[i]
    context = self.contexts[i]
    shared_data_arrays = self.shared_data_arrays[i]
    # Merge data and label descriptors into one name -> shape mapping.
    input_shapes = dict(data_shapes)
    if label_shapes is not None:
        input_shapes.update(dict(label_shapes))
    # Build a name -> dtype dict from the data descriptors; this is what
    # initialises the executor's input types.
    input_types = {x.name: x.dtype for x in data_shapes}
    if label_shapes is not None:
        input_types.update({x.name: x.dtype for x in label_shapes})
    group2ctx = self.group2ctxs[i]
    # simple_bind constructs the executor from the input descriptors;
    # presumably it also sets up the network's static graph and allocates
    # the needed arrays on `context` -- see the simple_bind docs.
    executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
                                       type_dict=input_types, shared_arg_names=self.param_names,
                                       shared_exec=shared_exec, group2ctx=group2ctx,
                                       shared_buffer=shared_data_arrays, **input_shapes)
    # NOTE(review): scrapes the memory total out of debug_str(); this is
    # fragile if the debug-string format ever changes.
    self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
    return executor
Module / Executor 訓練時用到mx.io.NDArrayIter
這裏重點關注每個epoch中,fit()如何調用mx.io.NDArrayIter提供數據用於訓練的
for epoch in range(begin_epoch, num_epoch):
tic = time.time()
eval_metric.reset()
nbatch = 0
'''iter()是用於把一個可迭代的非iter對象變成迭代器,如list
本blog中對train_data沒有任何改變
因爲傳入的train_data本身就是一個迭代器
'''
data_iter = iter(train_data)
'''
print(type(data_iter ),type(train_iter))
<class 'mxnet.io.NDArrayIter'> <class 'mxnet.io.NDArrayIter'>
'''
'''初始化一個標識,用於檢測iter是否到尾了'''
end_of_batch = False
'''獲得一個batch大小的data,<class 'mxnet.io.DataBatch'>
DataBatch 下有這兩個重要的成員變量
data <class 'list<class mx.NDArray>'>
label <class 'list<class mx.NDArray>'>
'''
next_data_batch = next(data_iter)
'''
def next(self):
    # iter_next() advances the cursor and reports whether data remains
    if self.iter_next():
        return DataBatch(data=self.getdata(), label=self.getlabel(), \
                         pad=self.getpad(), index=None)
    else:
        raise StopIteration

# self.cursor is initialised to -self.batch_size, so the first call to
# iter_next() moves it to 0 (the first sample to fetch); every call adds
# batch_size, and iteration continues while the cursor is still inside
# the data (cursor < num_data), otherwise it returns False.
def iter_next(self):
    self.cursor += self.batch_size
    return self.cursor < self.num_data
'''
while not end_of_batch:
data_batch = next_data_batch
if monitor is not None:
monitor.tic()
self.forward_backward(data_batch)
self.update()
'''處理完參數更新後,獲取新的batch'''
try:
# pre fetch next batch
next_data_batch = next(data_iter)
self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
'''except 和 raise對應'''
except StopIteration:
end_of_batch = True
重點看看iter.next().getdata()
def _getdata(self, data_source):
    """Load data from underlying arrays, internal use only.

    ``data_source`` is ``self.data`` or ``self.label``:
    ``list[(name, ndarray)]``.  Returns one batch-sized slice per member
    array, wrapping around to the start when the tail is short ('pad').
    """
    assert(self.cursor < self.num_data), "DataIter needs reset."
    # Does a full batch remain past the cursor?
    if self.cursor + self.batch_size <= self.num_data:
        return [
            # np.ndarray or NDArray case
            # Slice every member array to the
            # [cursor, cursor + batch_size) window.
            x[1][self.cursor:self.cursor + self.batch_size]
            if isinstance(x[1], (np.ndarray, NDArray)) else
            # h5py (only supports indices in increasing order)
            # so: fetch rows by the sorted shuffled indices, then reorder
            # them back into the original shuffled order.
            array(x[1][sorted(self.idx[
                self.cursor:self.cursor + self.batch_size])][[
                    list(self.idx[self.cursor:
                                  self.cursor + self.batch_size]).index(i)
                    for i in sorted(self.idx[
                        self.cursor:self.cursor + self.batch_size])
                ]]) for x in data_source
        ]
    else:
        # Final short batch: pad it by wrapping around and reusing the
        # first `pad` samples so the batch is always full-sized.
        pad = self.batch_size - self.num_data + self.cursor
        return [
            # np.ndarray or NDArray case
            concatenate([x[1][self.cursor:], x[1][:pad]])
            if isinstance(x[1], (np.ndarray, NDArray)) else
            # h5py (only supports indices in increasing order)
            concatenate([
                array(x[1][sorted(self.idx[self.cursor:])][[
                    list(self.idx[self.cursor:]).index(i)
                    for i in sorted(self.idx[self.cursor:])
                ]]),
                array(x[1][sorted(self.idx[:pad])][[
                    list(self.idx[:pad]).index(i)
                    for i in sorted(self.idx[:pad])
                ]])
            ]) for x in data_source
        ]
總結
- DataIter初始化時給定數據和數據名稱要和網絡輸入的symbol名稱對應
- DataIter用於網絡初始化和網絡訓練測試