上一節,我們給出了模型加載和保存的簡要示例,但是,我們有時候會用別人的參數,他們的層參數名和我們的名稱很容易不同,因此這裏將會對源碼進入深入剖析,分析參數提取和保存是如何實現的。
我們使用pytorch的VGG16預訓練模型,加載,返回其類型。可以發現,是OrderedDict類型,也就是字典類型,既然是字典,每個層的參數就是用了一個鍵值對保存起來了。
model = torch.load('vgg16-397923af.pth')
list_keys = list(model.keys()) # 將模型中的keys轉換爲list
print(type(model))
print(list_keys)
print(type(model[list_keys[0]]))
輸出:
<class 'collections.OrderedDict'>
['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']
<class 'torch.Tensor'>
很明顯,這個模型數據裏面就是這個模型所需的參數,每個參數用一個鍵值對存儲,每個參數都是一個Tensor矩陣。
關於這個參數的命名,因爲這裏面的Pytorch使用一個Sequence存的,所以這個參數名的命名規則就是. Sequence變量名+第幾層+每層內部的參數名
下面我給出個例子,說明如何將這些參數拷貝到自己的模型上,下面自己寫了一個VGG模型
class VGG(nn.Module):
def __init__(self, num_classes=100):
super(VGG, self).__init__()
layers = nn.ModuleList()
in_dim = 3
out_dim = 64
for i in range(13):
layers.extend([nn.Conv2d(in_dim, out_dim, 3, 1, 1),
nn.ReLU(inplace=True)])
out_dim = in_dim
if i == 1 or i == 3 or i == 6 or i == 9 or i == 12:
layers.append(nn.MaxPool2d(2,2))
if i != 9:
out_dim *= 2
self.fea = nn.Sequential(layers)
self.cls = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),
nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),
nn.Linear(4096, num_classes),
nn.Softmax(num_classes)
)
def forward(self, x):
x = self.fea(x)
x = x.view(x.size(0), -1)
x = self.cls(x)
return x
實例化一個VGG,並輸出這個模型所含的參數
vgg = VGG()
print(list(vgg.state_dict().keys()))
輸出:
['fea.0.0.weight', 'fea.0.0.bias', 'fea.0.2.weight', 'fea.0.2.bias', 'fea.0.5.weight', 'fea.0.5.bias', 'fea.0.7.weight', 'fea.0.7.bias', 'fea.0.10.weight', 'fea.0.10.bias', 'fea.0.12.weight', 'fea.0.12.bias', 'fea.0.14.weight', 'fea.0.14.bias', 'fea.0.17.weight', 'fea.0.17.bias', 'fea.0.19.weight', 'fea.0.19.bias', 'fea.0.21.weight', 'fea.0.21.bias', 'fea.0.24.weight', 'fea.0.24.bias', 'fea.0.26.weight', 'fea.0.26.bias', 'fea.0.28.weight', 'fea.0.28.bias', 'cls.0.weight', 'cls.0.bias', 'cls.3.weight', 'cls.3.bias', 'cls.6.weight', 'cls.6.bias']
根據之前的博客Pytorch學習(2) —— 網絡工具箱 TORCH.NN 基本類用法,我們使用load_state_dict進行模型加載
比如用下面的方法可以將另一個模型的參數轉到自己的參數上,記住strict一定要設置爲false,否則會出錯。
vgg.load_state_dict({'fea.0.0.weight':model['features.0.weight']}, strict=False)
總結
本部分介紹瞭如何將預訓練的模型參數加載到自己的模型上,有時候我們的網絡參數是由兩個其他網絡構成,那麼本部分提供了一種加載方法。
至此,模型的加載用法已經完成,下面就開始介紹如何構建模型。