Pytorch学习(七)跨设备保存和加载模型

在某些情况下,您可能需要在不同的设备上保存和加载您的神经网络。

介绍

使用PyTorch在不同设备之间保存和加载模型相对简单。在本菜谱中,我们将尝试跨cpu和gpu保存和加载模型。

步骤

1. 导入包

2. 定义和初始化神经网络

3. 在GPU上保存,在CPU上加载

4. 在GPU上保存,在GPU上加载

5. 在CPU上保存,在GPU上加载

6. 保存和加载DataParallel模型

1. Import necessary libraries for loading our data

import torch
import torch.nn as nn
import torch.optim as optim

2. Define and intialize the neural network

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

3. Save on GPU, Load on CPU

当在CPU上加载一个经过GPU训练的模型时,将torch.device(' cpu ')传递给torch.load()函数中的map_location参数。

# Specify a path to save to
PATH = "model.pt"

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device('cpu')
model = Net()
model.load_state_dict(torch.load(PATH, map_location=device))

在这种情况下,使用map_location参数将张量下面的存储动态地重新映射到CPU设备。

4. Save on GPU, Load on GPU

当在GPU上加载一个经过训练和保存的模型时,只需使用model.to(torch.device(' CUDA '))将初始化的模型转换为CUDA优化的模型。

确保用.to(torch.device('cuda'))函数。用于为模型准备数据的所有模型输入。

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = Net()
model.load_state_dict(torch.load(PATH))
model.to(device)

注意:调用my_tensor.to(device)返回一个GPU上my_tensor的副本。它不会重写my_tensor。所以,记住要手动重写tensors:

my_tensor = my_tensor.to(torch.device('cuda')).

5. Save on CPU, Load on GPU

当在GPU上加载一个经过训练并保存在CPU上的模型时,将torch.load()函数中的map_location参数设置为cuda:device_id。这将加载模型到给定的GPU设备。

一定要调用model.to(torch.device('cuda'))来将模型的参数张量转换为cuda张量。

最后,确保所有模型输入都使用.to(torch.device('cuda'))功能,为cuda优化的模型准备数据。

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = Net()
# Choose whatever GPU device number you want
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
model.to(device)

6. Saving torch.nn.DataParallel Models

torch.nn.DataParallel是一种能够实现并行GPU利用的模型包装器。

要一般地保存一个DataParallel模型,请保存model.module.state_dict()。通过这种方式,您可以灵活地以您想要的任何方式加载模型到您想要的任何设备。

# Save
torch.save(net.module.state_dict(), PATH)

# Load to whatever device you want

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章