整合Pytorch和MNN的嵌入式部署流程

https://zhuanlan.zhihu.com/p/76605363

https://zhuanlan.zhihu.com/p/76605363

https://zhuanlan.zhihu.com/p/76605363

工程的完整鏈接可以參考Github鏈接

Pytorch以其動態圖的調用方式,深得許多科研人員的喜愛,是許多人進行科研研究、算法預研的不二之選。本文我們跟大家討論一下,如何使用Pytorch來進行嵌入式的算法部署。這裏我們採用的離線訓練框架爲Pytorch,嵌入式端的推理框架爲阿里巴巴近期開源的高性能推理框架MNN。下面我們將結合MNIST這個簡單的分類任務來跟大家一步一步的完成嵌入式端的部署。

Pytorch的模型不能直接被MNN進行解析,所以我們這裏需要選定一個媒介。參考之前專欄的一篇文章《整合mxnet和MNN的嵌入式部署流程》,這裏也採用ONNX進行pytorch和MNN之間的橋樑。

  1. 模型的設計

模型的設計與《整合mxnet和MNN的嵌入式部署流程》文中的模型設計基本一樣,大家可以看下面代碼:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
import torch.optim as optim

class MNIST(nn.Module):

    def __init__(self):
        super(MNIST, self).__init__()
        self.conv0  = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, bias=False)
        self.bn0    = nn.BatchNorm2d(num_features=20)
        self.relu0  = nn.ReLU()
        self.maxp0  = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv1  = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, bias=False)
        self.bn1    = nn.BatchNorm2d(num_features=50)
        self.relu1  = nn.ReLU()
        self.maxp1  = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv2  = nn.Conv2d(in_channels=50, out_channels=500, kernel_size=4, stride=1, bias=False)
        self.bn2    = nn.BatchNorm2d(num_features=500)
        self.relu2  = nn.ReLU()
        self.conv3  = nn.Conv2d(in_channels=500, out_channels=10, kernel_size=1, stride=1, bias=False)
        # self.dense2 = nn.Linear(in_features=400, out_features=120, bias=False)
        # self.dp2    = nn.Dropout(p=0.5)
        # self.relu2  = nn.ReLU()
        # self.dense3 = nn.Linear(in_features=120, out_features=10, bias=False)

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.maxp0(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxp1(x)

        # x = x.view(-1, self.num_flat_features(x))
        # x = self.dense2(x)
        # x = self.dp2(x)
        # x = self.relu2(x)
        # x = self.dense3(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = torch.squeeze(x)
        return x

大家可以注意一下上述代碼裏的註釋部分,這裏我們進行一下介紹。由於Pytorch在實現矩陣乘法的時候,需要使用view進行數據的拉平,然後再進行matmul的操作,這幾個操作MNN並沒有進行支持(我第一次的實現是使用註釋部分的代碼,然後MNNConvert的時候報的錯誤,不支持一些op)。所以我們這裏採用了一個4*4和1*1的con2d來代替全連接層。

2. 導出ONNX模型

我們需要導出另外一種可以被MNN解析的模型格式,這裏我們選擇的是ONNX。如下爲導出ONNX的腳本文件:

import torch
import torch.nn as nn
import torch.onnx
from train_mnist import MNIST

# A model class instance (class not shown)
model = MNIST()

# Load the weights from a file (.pth usually)
weights_path = './mnist.pth'
state_dict = torch.load(weights_path)

# Load the weights now into a model net architecture defined by our class
model.load_state_dict(state_dict)

# Create the right input shape (e.g. for an image)
input = torch.randn(1, 1, 28, 28)

torch.onnx.export(model, input, "mnist.onnx", verbose=True)


import onnx

# Load the ONNX model
model = onnx.load("mnist.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
onnx.helper.printable_graph(model.graph)

3. 導出MNN模型

MNN提供了轉換ONNX到MNN模型的工具,執行如下腳本即可,關於MNN轉換工具編譯可以參考Model Conversion。下面是轉換腳本:

./MNNConvert -f ONNX --modelFile mnist.onnx --MNNModel mnist.mnn --bizCode MNN

輸出的結果如下:

Start to Convert Other Model Format To MNN Model...
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/onnx/onnxConverter.cpp:29: ONNX Model ir version: 3
Start to Optimize the MNN Net...
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/optimizer/optimizer.cpp:44: Inputs: 0
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/optimizer/optimizer.cpp:54: Outputs: 32, Type = Squeeze
Converted Done!

可以看出採用的ONNX IR版本爲3,輸入的節點名字爲0,輸出節點名字爲32.

4. 在線部署

在線部署流程在這裏,爲使用MNN加載解析好的mnn模型參數進行inference等一系列業務操作。關於如何在android上面使用mnn進行部署,本專欄已經有好幾篇介紹的文章,這裏就不進行贅述了。完整的JNI業務代碼可以參考如下鏈接JNI 業務代碼

  • 最後

選取的樣例爲簡單的mnist,雖然全連接層的實現在轉換過程中有一些小問題,但是我們修改了網絡結構,採用一個4x4和一個1x1的conv2d來進行替代,解決了模型轉換的問題。另外歡迎大家留言討論、關注本專欄及公衆號,謝謝大家!

發佈了229 篇原創文章 · 獲贊 169 · 訪問量 39萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章