動手學gluon系列之讀取預訓練模型----多種方法讀取預訓練模型進行finetune

本文主要是博主學習gluon時候的一些總結,共勉,如有錯誤,歡迎指正

gluon主要有3個方法得到預訓練模型:

  • gluon自身的model_zoo
  • gluoncv提供的model_zoo
  • mxnet提供的預訓練模型(.params ,.json)

下面分別就這三個方面進行介紹


一:讀取gluon model_zoo提供的模型,並進行finetune

gluon提供的model主要在gluon.model_zoo.vision下,模型地址:https://mxnet.incubator.apache.org/api/python/gluon/model_zoo.html,你可以根據自己的情況查找對應的模型進行使用。model_zoo提供的模型均爲features+output結構

調用方法如下:

一:只修改最終的fc層,進行finetune:

from mxnet import gluon
class_num = 3
ctx = [mx.gpu(0),mx.gpu(1)]

finetune_net = gluon.model_zoo.vision.resnet50_v2(pretrained=True)

with finetune_net.name_scope():
    finetune_net.output = nn.Dense(class_num)
finetune_net.output.initialize(init=mx.init.Xavier(),ctx=ctx)
finetune_net.hybridize()

二:不僅僅修改最終的fc層,還可以增加幾層

下面的方法,首先提取出features,然後構建增加的sequential,最後將兩部分通過sequential合併在一起。

from mxnet import gluon
pretrained_net = gluon.model_zoo.vision.resnet50_v2(pretrained=True)
pretrained_net_features = pretrained_net.features

class_num = 3
ctx = [mx.gpu(0),mx.gpu(1)]
modify_net = nn.HybridSequential(prefix="")
with modify_net.name_scope():
    modify_net.add(nn.Dense(128,activation='relu'),
                nn.Dropout(0.5),
                nn.Dense(class_num))
    modify_net.collect_params().initialize(init=mx.init.Xavier(),ctx=ctx)
    
net = nn.HybridSequential(prefix="")
with modify_net.name_scope():
    net.add(pretrained_net_features)
    net.add(modify_net)
net.hybridize() ## 該語句代表靜態圖動態圖切換。

也可以直接修改features,達到同樣的效果,不過記得初始化

from mxnet import gluon
class_num = 3
ctx = [mx.gpu(0),mx.gpu(1)]

finetune_net = gluon.model_zoo.vision.resnet50_v2(pretrained=True)

with finetune_net.name_scope():
    finetune_net.features.add(nn.Dense(128,activation='relu'),
                              nn.Dropout(0.5))
    finetune_net.output = nn.Dense(class_num)
finetune_net.features.initialize(init=mx.init.Xavier(),force_init=False,ctx=ctx)
finetune_net.output.initialize(init=mx.init.Xavier(),ctx=ctx)
finetune_net.hybridize()

二:讀取gluoncv model_zoo提供的模型,並進行finetune(推薦)

gluoncv是gluon提供的比較強大的視覺庫,其中提供了很多的預訓練模型可以使用,鏈接:https://gluon-cv.mxnet.io/model_zoo/classification.html

使用gluoncv的預訓練模型也很方便,跟使用gluon的model_zoo方法基本一致,不同點如下:

from gluoncv.model_zoo import get_model

finetune_net = get_model('ResNet50_v2', pretrained=True)

其他的就跟上面的一致了。
注意,有個gluoncv模型不是feature,output結構的,所以在使用的時候,可以看一下其結構,靈活判斷

三、直接讀取mxnet模型( .params + .json)

有的時候,我們可能需要利用gluon讀取mxnet模型,目前利用gluon讀取mxnet模型,只能使用gluon.nn.SymbolBlock()進行讀取,如下:

ctx = mx.gpu(0)
sym, arg_params, aux_params = mx.model.load_checkpoint('../model/resnetv1d-101',17) ## model path and model index
internals = sym.get_internals()
net_out = internals['fc1_output']

net = gluon.nn.SymbolBlock(outputs=net_out, inputs=mx.sym.var('data'))

net.load_params(filename='../model/resnetv1d-101-0017.params', ctx=ctx)

如上,我們便讀取了mxnet的model,現在我們便可以對net進行操作了,如下代碼構建了一個3分類的網絡:


class_num = 3
finetune_net = nn.HybridSequential(prefix="")
with finetune_net.name_scope():
    finetune_net.add(net)
    finetune_net.add(nn.Dense(class_num))## 輸出3分類
net.hybridize() ## 該語句代表靜態圖動態圖切換。

四、最優雅的方式,重新定義網絡,實現任意的操作,:

這種方法最爲優雅,也最靈活,你可以採用上面個各個方法讀取模型,然後重寫forward,實現網絡的任意操作

class PretrainedNetwork(gluon.HybridBlock):
    def __init__(self, pretrained_layer, **kwargs):
        super(PretrainedNetwork, self).__init__(**kwargs)
        with self.name_scope():
            self.pretrained_layer = pretrained_layer 
            self.fc = nn.HybridSequential()
            self.fc.add(
                        nn.Flatten(),
                        nn.Dense(256, activation = 'relu'),
                        nn.Dropout(rate = 0.5),
                        nn.Dense(128)
                        )
            self.output = nn.Dense(2)

            
    def hybrid_forward(self, F, x):  ## 這裏注意F不要忘記。
        x = self.pretrained_layer(x)
        x = self.fc(x)
        out = self.output(x)
     
        return out
        
        
        
### 採用如下得到網絡:

from gluoncv.model_zoo import get_model

finetune_net = get_model('ResNet50_v2', pretrained=True)    
net = PretrainedNetwork(pretrained_layer = finetune_net)
net.initialize(forece_reinit = False, init = init.Xavier()) ## 初始化


至此,應該常用的利用預訓練模型進行finetune的方法都包含了,如果還有更新,歡迎討論

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章