Pytorch學習(六)在PyTorch中使用不同模型的參數來預熱啓動模型

在遷移學習或訓練一個新的複雜模型時,部分加載模型或部分加載模型是常見的場景。利用訓練過的參數,即使只有少數是可用的,也將有助於熱身訓練過程,並有望幫助您的模型比從頭開始訓練更快地收斂。

介紹

無論您是從缺少一些keys的部分state_dict加載的,還是加載比您加載的模型keys更多的state_dict,都可以在load_state_dict()函數中設置嚴格參數爲False,以忽略不匹配的key。在這個食譜中,我們將實驗使用不同模型的參數來預熱一個模型。

步驟

1. 導入包

2. 定義和初始化神經網絡A和B

3. 保存模型A

4. 加載模型B

1. Import necessary libraries for loading our data

import torch
import torch.nn as nn
import torch.optim as optim

2. Define and intialize the neural network A and B

class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = NetA()

class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netB = NetB()

3. Save model A

# Specify a path to save to
PATH = "model.pt"

torch.save(netA.state_dict(), PATH)

4. Load into model B

netB.load_state_dict(torch.load(PATH), strict=False)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章