Initializing a PyTorch Model with an MXNet Pretrained Model

1. MXNet symbolic graph:

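A symbolic graph built with MXNet is a static computation graph. Both the graph structure and its memory allocation are fixed before execution. Taking ResNet50_v2 as an example, the symbolic graph of the Bottleneck unit is defined as follows: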

import mxnet as mx

# Function wrapper added so the fragment is self-contained; stride is a tuple such as (1,1) or (2,2)
def residual_unit(data, num_filter, stride, dim_match, name, bn_mom=0.9, workspace=256):
    """Pre-activation (v2) bottleneck unit: BN -> ReLU -> Conv, three times."""
    bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
    act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
    conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
                               no_bias=True, workspace=workspace, name=name + '_conv1')
    bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
    act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
    conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1),
                               no_bias=True, workspace=workspace, name=name + '_conv2')
    bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
    act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
    conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
                               workspace=workspace, name=name + '_conv3')
    if dim_match:
        # identity shortcut: input and output shapes already match
        shortcut = data
    else:
        # projection shortcut: 1x1 conv applied to the pre-activated input act1
        shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                                      workspace=workspace, name=name+'_sc')
    return conv3 + shortcut
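Tracing the tensor shapes through the unit shows why `conv3 + shortcut` is well defined. The following plain-Python sketch is an illustration and is not part of MXnet. It assumes an integer stride applied to both spatial axes and uses the standard convolution output-size formula out = (in + 2*pad - kernel) // stride + 1. The channel and spatial sizes in the example assume a 224x224 input to ResNet-50.

```python
def conv_out(size, kernel, stride, pad):
    """Spatial output size of a convolution along one axis."""
    return (size + 2 * pad - kernel) // stride + 1


def bottleneck_shapes(in_shape, num_filter, stride):
    """Trace (C, H, W) shapes through the pre-activation bottleneck above.

    in_shape: (C, H, W) of `data`; stride: an int used for both spatial axes.
    Returns the shapes of conv1, conv2, conv3 and the projection shortcut (_sc).
    """
    _, h, w = in_shape
    mid = int(num_filter * 0.25)
    conv1 = (mid, conv_out(h, 1, 1, 0), conv_out(w, 1, 1, 0))
    conv2 = (mid, conv_out(conv1[1], 3, stride, 1), conv_out(conv1[2], 3, stride, 1))
    conv3 = (num_filter, conv_out(conv2[1], 1, 1, 0), conv_out(conv2[2], 1, 1, 0))
    shortcut = (num_filter, conv_out(h, 1, stride, 0), conv_out(w, 1, stride, 0))
    return conv1, conv2, conv3, shortcut


# First unit of stage 2: 256 -> 512 channels, spatial size halved
c1, c2, c3, sc = bottleneck_shapes((256, 56, 56), num_filter=512, stride=2)
print(c1, c2, c3, sc)  # (128, 56, 56) (128, 28, 28) (512, 28, 28) (512, 28, 28)
```

The last two shapes are identical, which is what the element-wise sum in the graph requires.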

2. Loading the symbolic graph and model parameters:

An MXNet pretrained model consists of a JSON file that describes the graph and a .params file that holds the parameters:

-- resnet-50-0000.params

-- resnet-50-symbol.json

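Loading these two files gives you the symbolic graph, the model weights (arg params), and the auxiliary states (aux params, such as the BN moving statistics):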

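import os
import mxnet as mx

# args.epoch comes from the script's argparse options and ROOT_PATH is the project root;
# both are defined elsewhere in the original script
prefix, index, num_layer = 'resnet-50', args.epoch, 50
prefix = os.path.join(ROOT_PATH, "./mx_model/models/{}".format(prefix))
# reads <prefix>-symbol.json and <prefix>-<index, 4 digits>.params
symbol, param_args, param_auxs = mx.model.load_checkpoint(prefix, index)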
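`mx.model.load_checkpoint` reads `<prefix>-symbol.json` and `<prefix>-<epoch zero-padded to 4 digits>.params`. The keys in the .params file carry an `arg:` or `aux:` prefix, and that prefix determines whether a tensor goes into the weights dict or the auxiliary-states dict. The following pure-Python sketch reproduces only this naming and splitting logic on a plain dict. `checkpoint_files` and `split_arg_aux` are hypothetical helper names, and real loading still has to go through MXNet.

```python
def checkpoint_files(prefix, epoch):
    """File names used by an MXNet checkpoint with this prefix and epoch."""
    return '%s-symbol.json' % prefix, '%s-%04d.params' % (prefix, epoch)


def split_arg_aux(save_dict):
    """Split a {'arg:name': v, 'aux:name': v} dict into weights and auxiliary states."""
    arg_params, aux_params = {}, {}
    for key, value in save_dict.items():
        tp, name = key.split(':', 1)
        if tp == 'arg':
            arg_params[name] = value
        elif tp == 'aux':
            aux_params[name] = value
    return arg_params, aux_params


print(checkpoint_files('resnet-50', 0))  # ('resnet-50-symbol.json', 'resnet-50-0000.params')
args_, auxs_ = split_arg_aux({'arg:stage1_unit1_bn1_gamma': 1.0,
                              'aux:stage1_unit1_bn1_moving_var': 1.0})
print(sorted(args_), sorted(auxs_))  # ['stage1_unit1_bn1_gamma'] ['stage1_unit1_bn1_moving_var']
```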

3. PyTorch dynamic graph:

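PyTorch is a dynamic-graph (define-by-run) framework. The computation graph is built and its memory allocated on the fly as the code runs, which makes PyTorch well suited to research-oriented algorithm development. Because it follows an imperative programming style, you can inspect the values of tensors in the graph and their gradients at any point. The Bottleneck block of ResNet50_v2 looks like this: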

import torch.nn as nn

eps = 2e-5  # same BN epsilon as the MXNet symbol (PyTorch's default is 1e-5)

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes, eps)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, eps)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes, eps)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        if downsample:
            # projection shortcut, corresponds to name + '_sc' in the MXNet symbol
            self.conv_sc = nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=stride, bias=False)
        self.stride = stride

    def forward(self, input):
        # pre-activation: BN -> ReLU before each convolution
        out = self.bn1(input)
        out1 = self.relu(out)
        residual = input
        out = self.conv1(out1)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample:
            # as in MXNet, the projection shortcut takes the pre-activated input
            residual = self.conv_sc(out1)
        out += residual
        return out
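When porting BN layers, note that the two frameworks define `momentum` in opposite directions. MXNet updates moving = momentum * moving + (1 - momentum) * batch_stat. PyTorch updates running = (1 - momentum) * running + momentum * batch_stat. MXNet's bn_mom=0.9 therefore corresponds to PyTorch's default momentum=0.1, so the PyTorch block above can keep the default. The following NumPy sketch replays both update rules side by side. The helper names are illustrative only.

```python
import numpy as np


def mxnet_update(moving, batch_stat, momentum):
    # MXNet: moving = moving * momentum + batch_stat * (1 - momentum)
    return moving * momentum + batch_stat * (1 - momentum)


def pytorch_update(running, batch_stat, momentum):
    # PyTorch: running = (1 - momentum) * running + momentum * batch_stat
    return (1 - momentum) * running + momentum * batch_stat


rng = np.random.default_rng(0)
mx_stat = np.zeros(8)
pt_stat = np.zeros(8)
for _ in range(5):
    batch_mean = rng.normal(size=8)
    mx_stat = mxnet_update(mx_stat, batch_mean, momentum=0.9)    # bn_mom in the MXNet code
    pt_stat = pytorch_update(pt_stat, batch_mean, momentum=0.1)  # nn.BatchNorm2d default
print(np.allclose(mx_stat, pt_stat))  # True
```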

4. Parsing MXNet parameters and initializing the PyTorch model:

First, convert the MXNet parameters into a dictionary of NumPy arrays. The BN, Conv2D, and FC layers are parsed as follows:

import numpy as np

eps = 2e-5  # BN epsilon used by the MXNet symbol

def bn_parse(args, auxs, name, args_dict, fix_gamma=False):
    """ name[0]: PyTorch module name;
        name[1]: MXNet layer name."""
    args_dict[name[0]] = {}
    if not fix_gamma:
        # affine BN: copy moving statistics plus gamma/beta
        args_dict[name[0]]['running_mean'] = auxs[name[1]+'_moving_mean'].asnumpy()
        args_dict[name[0]]['running_var'] = auxs[name[1]+'_moving_var'].asnumpy()
        args_dict[name[0]]['gamma'] = args[name[1]+'_gamma'].asnumpy()
        args_dict[name[0]]['beta'] = args[name[1]+'_beta'].asnumpy()
    else:
        # fix_gamma=True: MXNet treats gamma as 1, so beta is folded into running_mean
        # for a PyTorch BN layer created with affine=False
        _mv = auxs[name[1]+'_moving_var'].asnumpy()
        _mm = auxs[name[1]+'_moving_mean'].asnumpy() - np.multiply(args[name[1]+'_beta'].asnumpy(), np.sqrt(_mv+eps))
        args_dict[name[0]]['running_mean'] = _mm
        args_dict[name[0]]['running_var'] = _mv
    return args_dict


def conv_parse(args, auxs, name, args_dict):
    """ name[0]: PyTorch module name;
        name[1]: MXNet layer name."""
    args_dict[name[0]] = {}
    w = args[name[1]+'_weight'].asnumpy()
    args_dict[name[0]]['weight'] = w  # (out_channels, in_channels, kH, kW), same layout as PyTorch
    return args_dict


def fc_parse(args, auxs, name, args_dict):
    """ name[0]: PyTorch module name;
        name[1]: MXNet layer name."""
    args_dict[name[0]] = {}
    # MXNet FullyConnected weight is (num_hidden, input_dim), same as nn.Linear
    args_dict[name[0]]['weight'] = args[name[1]+'_weight'].asnumpy()
    args_dict[name[0]]['bias'] = args[name[1]+'_bias'].asnumpy()
    return args_dict
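The `fix_gamma=True` branch of `bn_parse` depends on a small identity. With gamma fixed to 1, MXNet computes y = (x - mean) / sqrt(var + eps) + beta. A PyTorch BN layer without affine parameters (`affine=False`, so `m.weight is None` in `bn_init`) computes y = (x - running_mean) / sqrt(running_var + eps). Setting running_mean = mean - beta * sqrt(var + eps) makes the two outputs equal, so beta can be folded into the running mean. The following NumPy check uses random statistics and the eps=2e-5 from the MXNet symbol. It is an illustration and is not code from the original project.

```python
import numpy as np

eps = 2e-5
rng = np.random.default_rng(0)
C = 4
x = rng.normal(size=(2, C, 3, 3))              # NCHW input
moving_mean = rng.normal(size=C)
moving_var = rng.uniform(0.5, 2.0, size=C)
beta = rng.normal(size=C)


def mxnet_bn_fixed_gamma(x, mean, var, beta):
    # MXNet inference with fix_gamma=True: gamma is treated as 1, beta is still added
    s = (1, -1, 1, 1)
    return (x - mean.reshape(s)) / np.sqrt(var.reshape(s) + eps) + beta.reshape(s)


def pytorch_bn_no_affine(x, running_mean, running_var):
    # PyTorch BatchNorm2d(affine=False) in eval mode: no gamma, no beta
    s = (1, -1, 1, 1)
    return (x - running_mean.reshape(s)) / np.sqrt(running_var.reshape(s) + eps)


# Same folding as the else-branch of bn_parse
folded_mean = moving_mean - np.multiply(beta, np.sqrt(moving_var + eps))
y_mx = mxnet_bn_fixed_gamma(x, moving_mean, moving_var, beta)
y_pt = pytorch_bn_no_affine(x, folded_mean, moving_var)
print(np.allclose(y_mx, y_pt))  # True
```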

Next, iterate over every module of the PyTorch model and copy in its parameters. This initializes the PyTorch model with the MXNet pretrained weights:

import torch
import torch.nn as nn

# model initialization for PyTorch from MXnet params
class resnet(object):
    def __init__(self, name, num_layer, args, auxs, prefix='module.'):
        # prefix: 'module.' when the model is wrapped in nn.DataParallel
        self.name = name
        num_stages = 4
        if num_layer == 50:
            units = [3, 4, 6, 3]
        elif num_layer == 101:
            units = [3, 4, 23, 3]
        else:
            raise ValueError('unsupported num_layer: {}'.format(num_layer))
        self.num_layer = str(num_layer)
        # arg_parse (not shown in the original) builds {PyTorch module name: params}
        # from the bn_parse / conv_parse / fc_parse results
        self.param_dict = arg_parse(args, auxs, num_stages, units, prefix=prefix)

    def bn_init(self, n, m):
        if m.weight is not None:  # BN layers with affine=False have no gamma/beta
            m.weight.copy_(torch.as_tensor(self.param_dict[n]['gamma']))
            m.bias.copy_(torch.as_tensor(self.param_dict[n]['beta']))
        m.running_mean.copy_(torch.as_tensor(self.param_dict[n]['running_mean']))
        m.running_var.copy_(torch.as_tensor(self.param_dict[n]['running_var']))

    def conv_init(self, n, m):
        m.weight.copy_(torch.as_tensor(self.param_dict[n]['weight']))

    def fc_init(self, n, m):
        m.weight.copy_(torch.as_tensor(self.param_dict[n]['weight']))
        m.bias.copy_(torch.as_tensor(self.param_dict[n]['bias']))

    def init_model(self, model):
        # in-place copies into parameters must not be recorded by autograd
        with torch.no_grad():
            for n, m in model.named_modules():
                if isinstance(m, nn.BatchNorm2d):
                    self.bn_init(n, m)
                elif isinstance(m, nn.Conv2d):
                    self.conv_init(n, m)
                elif isinstance(m, nn.Linear):
                    self.fc_init(n, m)
        return model
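The original does not show `arg_parse`. Its job is to pair each PyTorch module name with the matching MXNet layer name and pass each pair to `bn_parse`, `conv_parse`, and `fc_parse`. The following sketch generates such pairs for the bottleneck units only. It rests on two assumptions that do not come from the original. First, the PyTorch model stores its stages as `layer1` to `layer4`, each an `nn.Sequential` of `Bottleneck` blocks indexed from 0. Second, the MXNet units are named `stage{i}_unit{j}` and indexed from 1, with a projection shortcut in the first unit of every stage. `bottleneck_name_pairs` is a hypothetical helper, and the stem and head layers would need their own entries.

```python
def bottleneck_name_pairs(units, prefix='module.'):
    """Hypothetical (PyTorch name, MXNet name) pairs for the bottleneck units.

    units: blocks per stage, e.g. [3, 4, 6, 3] for ResNet-50.
    prefix: 'module.' when the model is wrapped in nn.DataParallel.
    """
    bn_pairs, conv_pairs = [], []
    for i, num_units in enumerate(units):
        for j in range(num_units):
            pt_name = '{}layer{}.{}'.format(prefix, i + 1, j)   # PyTorch: 0-based block index
            mx_name = 'stage{}_unit{}'.format(i + 1, j + 1)     # MXNet: 1-based unit index
            for k in (1, 2, 3):
                bn_pairs.append(('{}.bn{}'.format(pt_name, k), '{}_bn{}'.format(mx_name, k)))
                conv_pairs.append(('{}.conv{}'.format(pt_name, k), '{}_conv{}'.format(mx_name, k)))
            if j == 0:
                # assumed: first unit of each stage uses the projection shortcut (dim_match=False)
                conv_pairs.append(('{}.conv_sc'.format(pt_name), '{}_sc'.format(mx_name)))
    return bn_pairs, conv_pairs


bn_pairs, conv_pairs = bottleneck_name_pairs([3, 4, 6, 3])
print(len(bn_pairs), len(conv_pairs))  # 48 52
print(bn_pairs[0])  # ('module.layer1.0.bn1', 'stage1_unit1_bn1')
```

Each pair matches the `name` argument of the parsers, so it can be passed directly, for example `bn_parse(param_args, param_auxs, pair, args_dict)`.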

5. Using MXNet's data loader:

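If you convert the output of mx.io.ImageRecordIter into PyTorch tensors, you can use it to train, validate, and test the PyTorch model. The iterator is designed as follows: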

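import torch

class MXDataIterWrapper(object):  # hypothetical wrapper name; the original shows only __iter__
    def __init__(self, data_iter, cuda=True):
        self.data = data_iter  # an mx.io.ImageRecordIter
        self.cuda = cuda

    def __iter__(self):
        # MXNet iterators must be reset before each new pass (e.g. every epoch)
        self.data.reset()
        for batch in self.data:
            nd_data = batch.data[0].asnumpy()
            nd_label = batch.label[0].asnumpy()
            input_data = torch.from_numpy(nd_data).float()
            # ImageRecordIter yields float32 labels; classification losses expect int64
            input_label = torch.from_numpy(nd_label).long()

            if self.cuda:
                yield input_data.cuda(non_blocking=True), input_label.cuda(non_blocking=True)
            else:
                yield input_data, input_label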
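With `round_batch` enabled on `mx.io.ImageRecordIter`, the iterator fills the last batch of an epoch with samples that wrap around from the start of the dataset. `batch.pad` reports how many of these padded samples the batch contains. For validation or testing you may want to drop them before converting to tensors. The following NumPy sketch shows the slicing. `strip_pad` is a hypothetical helper. Inside the iterator above, you would apply it to `nd_data` and `nd_label` with `batch.pad`.

```python
import numpy as np


def strip_pad(nd_data, nd_label, pad):
    """Drop the last `pad` samples that the iterator added to fill the batch."""
    if pad:  # guard needed: slicing with [:-0] would return an empty array
        return nd_data[:-pad], nd_label[:-pad]
    return nd_data, nd_label


data = np.zeros((8, 3, 4, 4), dtype=np.float32)  # a batch of 8, NCHW
label = np.arange(8, dtype=np.float32)           # float32 labels, as ImageRecordIter yields
d, l = strip_pad(data, label, pad=3)
print(d.shape, l.astype(np.int64))  # (5, 3, 4, 4) [0 1 2 3 4]
```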



