Pytorch學習(七)跨設備保存和加載模型

在某些情況下,您可能需要在不同的設備上保存和加載您的神經網絡。

介紹

使用PyTorch在不同設備之間保存和加載模型相對簡單。在本菜譜中,我們將嘗試跨cpu和gpu保存和加載模型。

步驟

1. 導入包

2. 定義和初始化神經網絡

3. 在GPU上保存,在CPU上加載

4. 在GPU上保存,在GPU上加載

5. 在CPU上保存,在GPU上加載

6. 保存和加載DataParallel模型

1. Import necessary libraries for loading our data

import torch
import torch.nn as nn
import torch.optim as optim

2. Define and intialize the neural network

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

3. Save on GPU, Load on CPU

當在CPU上加載一個經過GPU訓練的模型時,將torch.device(' cpu ')傳遞給torch.load()函數中的map_location參數。

# Specify a path to save to
PATH = "model.pt"

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device('cpu')
model = Net()
model.load_state_dict(torch.load(PATH, map_location=device))

在這種情況下,使用map_location參數將張量下面的存儲動態地重新映射到CPU設備。

4. Save on GPU, Load on GPU

當在GPU上加載一個經過訓練和保存的模型時,只需使用model.to(torch.device(' CUDA '))將初始化的模型轉換爲CUDA優化的模型。

確保用.to(torch.device('cuda'))函數。用於爲模型準備數據的所有模型輸入。

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = Net()
model.load_state_dict(torch.load(PATH))
model.to(device)

注意:調用my_tensor.to(device)返回一個GPU上my_tensor的副本。它不會重寫my_tensor。所以,記住要手動重寫tensors:

my_tensor = my_tensor.to(torch.device('cuda')).

5. Save on CPU, Load on GPU

當在GPU上加載一個經過訓練並保存在CPU上的模型時,將torch.load()函數中的map_location參數設置爲cuda:device_id。這將加載模型到給定的GPU設備。

一定要調用model.to(torch.device('cuda'))來將模型的參數張量轉換爲cuda張量。

最後,確保所有模型輸入都使用.to(torch.device('cuda'))功能,爲cuda優化的模型準備數據。

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = Net()
# Choose whatever GPU device number you want
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
model.to(device)

6. Saving torch.nn.DataParallel Models

torch.nn.DataParallel是一種能夠實現並行GPU利用的模型包裝器。

要一般地保存一個DataParallel模型,請保存model.module.state_dict()。通過這種方式,您可以靈活地以您想要的任何方式加載模型到您想要的任何設備。

# Save
torch.save(net.module.state_dict(), PATH)

# Load to whatever device you want

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章