Converting a PyTorch model to MXNet

Introduction

Gluon is a higher-level wrapper around MXNet whose API style is very close to PyTorch's.

The advantage of using Gluon is that it makes porting a PyTorch model to MXNet very easy.

The only problem is that the Gluon wrapper is still immature: it covers relatively few layers, and many common ones, such as concat and upsampling, are missing.

This post therefore focuses on how to quickly convert a PyTorch model into an MXNet network built on the symbol and executor APIs.

PyTorch to MXNet module

Key points:

  • When designing the MXNet network, the symbol names must match the layer names declared in the PyTorch model's initializer.
  • torch.load() reads the PyTorch checkpoint as a dict; its 'state_dict' entry is itself a dict.
  • In the PyTorch state_dict, every key is a parameter name and every value is the parameter array.
  • PyTorch organizes parameter names the same way MXNet does, but with a different separator: PyTorch uses '.' while MXNet uses '_'. For example:

pytorch '0.conv1.0.weight'

mxnet  '0_conv1_0_weight'

  • The PyTorch parameter arrays are identical to MXNet's, so once the names line up, direct assignment is enough to initialize the MXNet model.

What needs to be done:

  • Design an MXNet network that mirrors the PyTorch network.
  • Load the PyTorch checkpoint.
  • Rename the keys of the checkpoint's state_dict to MXNet's naming format, as sketched below.
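A minimal sketch of that renaming step (assuming state_dict is the loaded PyTorch dict and exe a bound MXNet executor, both of which appear in the full script below):

for key, tensor in state_dict.items():
    key_mx = key.replace('.', '_')  # '0.conv1.0.weight' -> '0_conv1_0_weight'
    if key_mx in exe.arg_dict:
        exe.arg_dict[key_mx][:] = tensor.cpu().numpy()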

FlowNet2S PytorchToMxnet

The PyTorch FlowNet2-S checkpoint can be found on GitHub.

import mxnet as mx
from symbol_util import *  # expected to provide the custom SparseRegressionLoss op
import pickle

def get_loss(data, label, loss_scale, name, get_input=False, is_sparse=False, type='stereo'):
    # stereo predictions pass through a ReLU before the loss
    if type == 'stereo':
        data = mx.sym.Activation(data=data, act_type='relu', name=name + 'relu')
    # choose between the custom sparse L1 loss and the built-in MAE loss
    if is_sparse:
        loss = mx.symbol.Custom(data=data, label=label, name=name, loss_scale=loss_scale,
                                is_l1=True, op_type='SparseRegressionLoss')
    else:
        loss = mx.sym.MAERegressionOutput(data=data, label=label, name=name, grad_scale=loss_scale)
    return (loss, data) if get_input else loss
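# Note: get_loss switches between the custom SparseRegressionLoss op (for
# sparse ground truth) and the built-in MAERegressionOutput; flownet_s below
# calls MAERegressionOutput directly, so this helper is kept mainly for reference.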


def flownet_s(loss_scale, is_sparse=False, name=''):
    img1 = mx.symbol.Variable('img1')
    img2 = mx.symbol.Variable('img2')
    data = mx.symbol.concat(img1,img2,dim=1)
    labels = {'loss{}'.format(i): mx.sym.Variable('loss{}_label'.format(i)) for i in range(0, 7)}
    prediction = {}  # per-scale flow predictions, keyed by loss name
    loss = []  # per-scale loss symbols

    # normalize the input images
    data = (data - 125) / 255

    # extract features
    conv1 = mx.sym.Convolution(data, pad=(3, 3), kernel=(7, 7), stride=(2, 2), num_filter=64, name=name + 'conv1_0')
    conv1 = mx.sym.LeakyReLU(data=conv1, act_type='leaky', slope=0.1)

    conv2 = mx.sym.Convolution(conv1, pad=(2, 2), kernel=(5, 5), stride=(2, 2), num_filter=128, name=name + 'conv2_0')
    conv2 = mx.sym.LeakyReLU(data=conv2, act_type='leaky', slope=0.1)

    conv3a = mx.sym.Convolution(conv2, pad=(2, 2), kernel=(5, 5), stride=(2, 2), num_filter=256, name=name + 'conv3_0')
    conv3a = mx.sym.LeakyReLU(data=conv3a, act_type='leaky', slope=0.1)

    conv3b = mx.sym.Convolution(conv3a, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=256, name=name + 'conv3_1_0')
    conv3b = mx.sym.LeakyReLU(data=conv3b, act_type='leaky', slope=0.1)

    conv4a = mx.sym.Convolution(conv3b, pad=(1, 1), kernel=(3, 3), stride=(2, 2), num_filter=512, name=name + 'conv4_0')
    conv4a = mx.sym.LeakyReLU(data=conv4a, act_type='leaky', slope=0.1)

    conv4b = mx.sym.Convolution(conv4a, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=512, name=name + 'conv4_1_0')
    conv4b = mx.sym.LeakyReLU(data=conv4b, act_type='leaky', slope=0.1)

    conv5a = mx.sym.Convolution(conv4b, pad=(1, 1), kernel=(3, 3), stride=(2, 2), num_filter=512, name=name + 'conv5_0')
    conv5a = mx.sym.LeakyReLU(data=conv5a, act_type='leaky', slope=0.1)

    conv5b = mx.sym.Convolution(conv5a, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=512, name=name + 'conv5_1_0')
    conv5b = mx.sym.LeakyReLU(data=conv5b, act_type='leaky', slope=0.1)

    conv6a = mx.sym.Convolution(conv5b, pad=(1, 1), kernel=(3, 3), stride=(2, 2), num_filter=1024, name=name + 'conv6_0')
    conv6a = mx.sym.LeakyReLU(data=conv6a, act_type='leaky', slope=0.1)

    conv6b = mx.sym.Convolution(conv6a, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=1024,
                                name=name + 'conv6_1_0')
    conv6b = mx.sym.LeakyReLU(data=conv6b, act_type='leaky', slope=0.1)

    # predict flow at the coarsest scale
    pr6 = mx.sym.Convolution(conv6b, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=2,
                             name=name + 'predict_flow6')
    prediction['loss6'] = pr6

    upsample_pr6to5 = mx.sym.Deconvolution(pr6, pad=(1, 1), kernel=(4, 4), stride=(2, 2), num_filter=2,
                                           name=name + 'upsampled_flow6_to_5', no_bias=True)
    upconv5 = mx.sym.Deconvolution(conv6b, pad=(1, 1), kernel=(4, 4), stride=(2, 2), num_filter=512,
                                   name=name + 'deconv5_0', no_bias=False)
    upconv5 = mx.sym.LeakyReLU(data=upconv5, act_type='leaky', slope=0.1)
    iconv5 = mx.sym.Concat(conv5b, upconv5, upsample_pr6to5, dim=1)


    pr5 = mx.sym.Convolution(iconv5, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=2,
                             name=name + 'predict_flow5')
    prediction['loss5'] = pr5

    upconv4 = mx.sym.Deconvolution(iconv5, pad=(1, 1), kernel=(4, 4), stride=(2, 2), num_filter=256,
                                   name=name + 'deconv4_0', no_bias=False)
    upconv4 = mx.sym.LeakyReLU(data=upconv4, act_type='leaky', slope=0.1)

    upsample_pr5to4 = mx.sym.Deconvolution(pr5, pad=(1, 1), kernel=(4, 4), stride=(2, 2), num_filter=2,
                                           name=name + 'upsampled_flow5_to_4', no_bias=True)

    iconv4 = mx.sym.Concat(conv4b, upconv4, upsample_pr5to4)

    pr4 = mx.sym.Convolution(iconv4, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=2,
                             name=name + 'predict_flow4')
    prediction['loss4'] = pr4

    upconv3 = mx.sym.Deconvolution(iconv4, pad=(1, 1), kernel=(4, 4), stride=(2, 2), num_filter=128,
                                   name=name + 'deconv3_0', no_bias=False)
    upconv3 = mx.sym.LeakyReLU(data=upconv3, act_type='leaky', slope=0.1)

    upsample_pr4to3 = mx.sym.Deconvolution(pr4, pad=(1, 1), kernel=(4, 4), stride=(2, 2), num_filter=2,
                                           name= name + 'upsampled_flow4_to_3', no_bias=True)
    iconv3 = mx.sym.Concat(conv3b, upconv3, upsample_pr4to3)

    pr3 = mx.sym.Convolution(iconv3, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=2,
                             name=name + 'predict_flow3')
    prediction['loss3'] = pr3

    upconv2 = mx.sym.Deconvolution(iconv3, pad=(1, 1), kernel=(4, 4), stride=(2, 2), num_filter=64,
                                   name=name + 'deconv2_0', no_bias=False)
    upconv2 = mx.sym.LeakyReLU(data=upconv2, act_type='leaky', slope=0.1)

    upsample_pr3to2 = mx.sym.Deconvolution(pr3, pad=(1, 1), kernel=(4, 4), stride=(2, 2), num_filter=2,
                                           name=name + 'upsampled_flow3_to_2', no_bias=True)
    iconv2 = mx.sym.Concat(conv2, upconv2, upsample_pr3to2)

    pr2 = mx.sym.Convolution(iconv2, pad=(1, 1), kernel=(3, 3), stride=(1, 1), num_filter=2,
                             name=name + 'predict_flow2')
    prediction['loss2'] = pr2
    # upsample the finest prediction (1/4 resolution) back to full resolution
    flow = mx.sym.UpSampling(pr2, scale=4, num_filter=2, num_args=1, sample_type='nearest',
                             name='upsample_flow2_to_1')
    # build a loss symbol for every scale listed in loss_scale
    for key in sorted(loss_scale.keys()):
        loss.append(mx.sym.MAERegressionOutput(data=prediction[key] * 20, label=labels[key],
                                               name=key + name, grad_scale=loss_scale[key]))
    # Group combines the per-scale losses into a single symbol with multiple outputs
    loss_group = mx.sym.Group(loss)
    return loss_group, flow

import gluonbook as gb  # provides try_gpu()
import torch
from utils.frame_utils import *  # read_gen, from the pytorch flownet2 repo
import numpy as np
if __name__ == '__main__':
    checkpoint = torch.load("C:/Users/junjie.huang/PycharmProjects/flownet2_mxnet/flownet2_pytorch/FlowNet2-S_checkpoint.pth.tar")
    # checkpoint is a dict
    print(isinstance(checkpoint['state_dict'], dict))
    # print the keys of the checkpoint dict
    print('keys of checkpoint:')
    for i in checkpoint:
        print(i)
    print('')
    # the PyTorch model parameters live under the 'state_dict' key
    state_dict = checkpoint['state_dict']
    # state_dict is also a dict
    print('keys of state_dict:')
    for i in state_dict:
        print(i)
        # print(state_dict[i].size())
    print('')
    # the dict values are torch.Tensors
    print(torch.is_tensor(state_dict['conv1.0.weight']))
    # inspect the size of one parameter
    print(state_dict['conv1.0.weight'].size())

    # flownet-mxnet init: build the loss and flow symbols
    loss_scale={'loss2': 1.00,
               'loss3': 1.00,
               'loss4': 1.00,
               'loss5': 1.00,
               'loss6': 1.00}
    loss,flow = flownet_s(loss_scale=loss_scale,is_sparse=False)
    print('loss information: ')
    print('loss:',loss)
    print('type:',type(loss))
    print('list_arguments:',loss.list_arguments())
    print('list_outputs:',loss.list_outputs())
    print('list_inputs:',loss.list_inputs())
    print('')

    print('flow information: ')
    print('flow:',flow)
    print('type:',type(flow))
    print('list_arguments:',flow.list_arguments())
    print('list_outputs:',flow.list_outputs())
    print('list_inputs:',flow.list_inputs())
    print('')
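    # (sketch) sanity-check the flow output shape before binding; only
    # img1/img2 are needed because flow does not depend on the label variables
    _, out_shapes, _ = flow.infer_shape(img1=(1, 3, 384, 512), img2=(1, 3, 384, 512))
    print(out_shapes)  # expected: [(1, 2, 384, 512)]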
    name_mxnet = loss.list_arguments()
    print(type(name_mxnet))
    for key in name_mxnet:
        print(key)

    name_mxnet.sort()
    for key in name_mxnet:
        print(key)
    print(name_mxnet)

    shapes = (1, 3, 384, 512)
    ctx = gb.try_gpu()
    exe = flow.simple_bind(ctx=ctx, img1=shapes, img2=shapes)
    print('exe type: ',type(exe))
    print('exe:  ',exe)
    # alternatively, the symbol could be wrapped in a Module instead of
    # binding an executor directly:
    # mod = mx.mod.Module(flow)

    pim1 = read_gen("C:/Users/junjie.huang/PycharmProjects/flownet2_mxnet/data/0000007-img0.ppm")
    pim2 = read_gen("C:/Users/junjie.huang/PycharmProjects/flownet2_mxnet/data/0000007-img1.ppm")
    print(pim1.shape)

    # initialize the MXNet model parameters from the PyTorch state_dict
    for key in state_dict:
        key_mx = key.replace('.', '_')  # '.'-separated names -> '_'-separated
        if key_mx not in exe.arg_dict:
            print('no MXNet argument for', key)
            continue
        try:
            exe.arg_dict[key_mx][:] = state_dict[key].cpu().numpy()
        except Exception:
            # report shape mismatches instead of failing silently
            print(key, exe.arg_dict[key_mx].shape, state_dict[key].shape)
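    # follow-up check (sketch): report MXNet arguments that were never filled
    # from the checkpoint; the 'img' and 'loss' prefixes cover this network's
    # data and label variables, everything else should carry weights
    copied = {k.replace('.', '_') for k in state_dict}
    for arg in exe.arg_dict:
        if arg not in copied and not arg.startswith(('img', 'loss')):
            print('not initialized from checkpoint:', arg)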

    
    # HWC uint8 images -> NCHW float batch of size 1
    exe.arg_dict['img1'][:] = pim1[np.newaxis, :, :, :].transpose(0, 3, 1, 2)
    exe.arg_dict['img2'][:] = pim2[np.newaxis, :, :, :].transpose(0, 3, 1, 2)

    result = exe.forward()
    print('result:  ', type(result))
    outputs = exe.outputs
    print('output type:  ', type(outputs))

    # flow2color comes from the pytorch flownet2 repo
    from visualize import flow2color
    flow_color = flow2color(exe.outputs[0].asnumpy()[0].transpose(1, 2, 0))
    print('color type:', type(flow_color))
    # display the flow visualization via torchvision
    from torchvision.transforms import ToPILImage
    TF = ToPILImage()
    images = TF(flow_color)
    images.show()
    # alternatively: import matplotlib.pyplot as plt; plt.imshow(flow_color); plt.show()

 
