在遷移學習或訓練一個新的複雜模型時,部分加載模型或加載部分模型是常見的場景。利用訓練過的參數,即使只有少數參數可用,也有助於熱身(warmstart)訓練過程,並有望幫助您的模型比從頭開始訓練更快地收斂。
介紹
無論您是從缺少某些 keys 的部分 state_dict 加載,還是加載的 state_dict 的 keys 多於要加載到的模型,都可以在 load_state_dict() 函數中將 strict 參數設置爲 False,以忽略不匹配的 keys。在這個食譜中,我們將實驗使用不同模型的參數來預熱一個模型。
步驟
1. 導入包
2. 定義和初始化神經網絡A和B
3. 保存模型A
4. 加載模型B
1. Import necessary libraries for loading our data
import torch
import torch.nn as nn
import torch.optim as optim
2. Define and initialize the neural network A and B
class NetA(nn.Module):
    """LeNet-style CNN used as the *source* model for warmstarting.

    Architecture: two conv+max-pool stages followed by three fully
    connected layers. The fc1 input size of 16 * 5 * 5 implies an
    expected input of shape (N, 3, 32, 32), producing (N, 10) logits.
    """

    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # BUG FIX: the original called F.relu, but `torch.nn.functional`
        # is never imported in this file, so the first forward pass would
        # raise NameError. torch.relu is identical and already in scope.
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten conv features to (N, 400)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)  # raw logits; no softmax (loss fn applies it)
        return x


netA = NetA()
class NetB(nn.Module):
    """LeNet-style CNN used as the *target* model for warmstarting.

    Intentionally identical in structure to NetA so that NetA's saved
    state_dict can be loaded into it (with strict=False tolerating any
    key mismatches). Expects (N, 3, 32, 32) input, returns (N, 10) logits.
    """

    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # BUG FIX: the original called F.relu, but `torch.nn.functional`
        # is never imported in this file, so the first forward pass would
        # raise NameError. torch.relu is identical and already in scope.
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten conv features to (N, 400)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)  # raw logits; no softmax (loss fn applies it)
        return x


netB = NetB()
3. Save model A
# Destination file for model A's serialized parameters.
PATH = "model.pt"

# Persist only the learned parameters (state_dict), not the full module.
state_dict_a = netA.state_dict()
torch.save(state_dict_a, PATH)
4. Load into model B
# Warmstart B from A's checkpoint; strict=False ignores any keys that
# do not match between the loaded state_dict and netB's own state_dict.
checkpoint = torch.load(PATH)
netB.load_state_dict(checkpoint, strict=False)