Pytorch學習(6) —— 加載模型部分參數的用法

上一節,我們給出了模型加載和保存的簡要示例,但是,我們有時候會用別人的參數,他們的層參數名和我們的名稱很容易不同,因此這裏將會對源碼進入深入剖析,分析參數提取和保存是如何實現的。

我們使用pytorch的VGG16預訓練模型,加載,返回其類型。可以發現,是OrderedDict類型,也就是字典類型,既然是字典,每個層的參數就是用了一個鍵值對保存起來了。

model = torch.load('vgg16-397923af.pth')
list_keys = list(model.keys()) # 將模型中的keys轉換爲list
print(type(model))
print(list_keys)
print(type(model[list_keys[0]]))
輸出:
<class 'collections.OrderedDict'>
['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']
<class 'torch.Tensor'>

很明顯,這個模型數據裏面就是這個模型所需的參數,每個參數用一個鍵值對存儲,每個參數都是一個Tensor矩陣。

關於這個參數的命名,因爲這裏面的Pytorch使用一個Sequence存的,所以這個參數名的命名規則就是. Sequence變量名+第幾層+每層內部的參數名

下面我給出個例子,說明如何將這些參數拷貝到自己的模型上,下面自己寫了一個VGG模型

class VGG(nn.Module):
    def __init__(self, num_classes=100):
        super(VGG, self).__init__()
        layers = nn.ModuleList()
        in_dim = 3
        out_dim = 64
        for i in range(13):
            layers.extend([nn.Conv2d(in_dim, out_dim, 3, 1, 1),
                           nn.ReLU(inplace=True)])
            out_dim = in_dim
            if i == 1 or i == 3 or i == 6 or i == 9 or i == 12:
                layers.append(nn.MaxPool2d(2,2))
                if i != 9:
                    out_dim *= 2
        self.fea = nn.Sequential(layers)
        self.cls = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),
            nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),
            nn.Linear(4096, num_classes),
            nn.Softmax(num_classes)
        )

        def forward(self, x):
            x = self.fea(x)
            x = x.view(x.size(0), -1)
            x = self.cls(x)
            return x

實例化一個VGG,並輸出這個模型所含的參數

vgg = VGG()
print(list(vgg.state_dict().keys()))
輸出:
['fea.0.0.weight', 'fea.0.0.bias', 'fea.0.2.weight', 'fea.0.2.bias', 'fea.0.5.weight', 'fea.0.5.bias', 'fea.0.7.weight', 'fea.0.7.bias', 'fea.0.10.weight', 'fea.0.10.bias', 'fea.0.12.weight', 'fea.0.12.bias', 'fea.0.14.weight', 'fea.0.14.bias', 'fea.0.17.weight', 'fea.0.17.bias', 'fea.0.19.weight', 'fea.0.19.bias', 'fea.0.21.weight', 'fea.0.21.bias', 'fea.0.24.weight', 'fea.0.24.bias', 'fea.0.26.weight', 'fea.0.26.bias', 'fea.0.28.weight', 'fea.0.28.bias', 'cls.0.weight', 'cls.0.bias', 'cls.3.weight', 'cls.3.bias', 'cls.6.weight', 'cls.6.bias']

根據之前的博客Pytorch學習(2) —— 網絡工具箱 TORCH.NN 基本類用法,我們使用load_state_dict進行模型加載

比如用下面的方法可以將另一個模型的參數轉到自己的參數上,記住strict一定要設置爲false,否則會出錯。

vgg.load_state_dict({'fea.0.0.weight':model['features.0.weight']}, strict=False)

總結

本部分介紹瞭如何將預訓練的模型參數加載到自己的模型上,有時候我們的網絡參數是由兩個其他網絡構成,那麼本部分提供了一種加載方法。

至此,模型的加載用法已經完成,下面就開始介紹如何構建模型。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章